Commit 516a2f1f authored by David S. Miller
Browse files

Merge branch 'tls-rx-refactoring-part-2'

Jakub Kicinski says:

====================
tls: rx: random refactoring part 2

TLS Rx refactoring. Part 2 of 3. This one focuses on the main loop.
A couple of features to follow.
====================
parents 626a5aaa f940b6ef
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -152,7 +152,6 @@ struct tls_sw_context_rx {
	atomic_t decrypt_pending;
	/* protects crypto_wait together with decrypt_pending */
	spinlock_t decrypt_compl_lock;
	bool async_notify;
};

struct tls_record_info {
+104 −160
Original line number Diff line number Diff line
@@ -44,6 +44,11 @@
#include <net/strparser.h>
#include <net/tls.h>

/* Per-record decryption options, passed by pointer through
 * decrypt_skb_update() and decrypt_internal() in place of separate
 * zc/async parameters.
 */
struct tls_decrypt_arg {
	/* In/out: the caller sets it to request zero-copy decryption; the
	 * decrypt path clears it when the record was already decrypted or
	 * when it falls back to a non-zero-copy path, so the caller must
	 * re-check it after the call.
	 */
	bool zc;
	/* In: forwarded to tls_do_decryption() as its async argument. */
	bool async;
};

noinline void tls_err_abort(struct sock *sk, int err)
{
	WARN_ON_ONCE(err >= 0);
@@ -168,7 +173,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
	struct scatterlist *sg;
	struct sk_buff *skb;
	unsigned int pages;
	int pending;

	skb = (struct sk_buff *)req->data;
	tls_ctx = tls_get_ctx(skb->sk);
@@ -216,9 +220,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
	kfree(aead_req);

	spin_lock_bh(&ctx->decrypt_compl_lock);
	pending = atomic_dec_return(&ctx->decrypt_pending);

	if (!pending && ctx->async_notify)
	if (!atomic_dec_return(&ctx->decrypt_pending))
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->decrypt_compl_lock);
}
@@ -1345,15 +1347,14 @@ static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
	return skb;
}

static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
static int tls_setup_from_iter(struct iov_iter *from,
			       int length, int *pages_used,
			       unsigned int *size_used,
			       struct scatterlist *to,
			       int to_max_pages)
{
	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
	struct page *pages[MAX_SKB_FRAGS];
	unsigned int size = *size_used;
	unsigned int size = 0;
	ssize_t copied, use;
	size_t offset;

@@ -1396,8 +1397,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
		sg_mark_end(&to[num_elem - 1]);
out:
	if (rc)
		iov_iter_revert(from, size - *size_used);
	*size_used = size;
		iov_iter_revert(from, size);
	*pages_used = num_elem;

	return rc;
@@ -1414,7 +1414,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			    struct iov_iter *out_iov,
			    struct scatterlist *out_sg,
			    int *chunk, bool *zc, bool async)
			    struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1431,7 +1431,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			     prot->tail_size;
	int iv_offset = 0;

	if (*zc && (out_iov || out_sg)) {
	if (darg->zc && (out_iov || out_sg)) {
		if (out_iov)
			n_sgout = 1 +
				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
@@ -1441,7 +1441,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
				 rxm->full_len - prot->prepend_size);
	} else {
		n_sgout = 0;
		*zc = false;
		darg->zc = false;
		n_sgin = skb_cow_data(skb, 0, &unused);
	}

@@ -1523,9 +1523,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			sg_init_table(sgout, n_sgout);
			sg_set_buf(&sgout[0], aad, prot->aad_size);

			*chunk = 0;
			err = tls_setup_from_iter(sk, out_iov, data_len,
						  &pages, chunk, &sgout[1],
			err = tls_setup_from_iter(out_iov, data_len,
						  &pages, &sgout[1],
						  (n_sgout - 1));
			if (err < 0)
				goto fallback_to_reg_recv;
@@ -1538,13 +1537,12 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
fallback_to_reg_recv:
		sgout = sgin;
		pages = 0;
		*chunk = data_len;
		*zc = false;
		darg->zc = false;
	}

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
				data_len, aead_req, async);
				data_len, aead_req, darg->async);
	if (err == -EINPROGRESS)
		return err;

@@ -1557,8 +1555,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
}

static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			      struct iov_iter *dest, int *chunk, bool *zc,
			      bool async)
			      struct iov_iter *dest,
			      struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -1567,7 +1565,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
	int pad, err;

	if (tlm->decrypted) {
		*zc = false;
		darg->zc = false;
		return 0;
	}

@@ -1577,12 +1575,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			return err;
		if (err > 0) {
			tlm->decrypted = 1;
			*zc = false;
			darg->zc = false;
			goto decrypt_done;
		}
	}

	err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
	err = decrypt_internal(sk, skb, dest, NULL, darg);
	if (err < 0) {
		if (err == -EINPROGRESS)
			tls_advance_record_sn(sk, prot, &tls_ctx->rx);
@@ -1608,34 +1606,32 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
		struct scatterlist *sgout)
{
	bool zc = true;
	int chunk;
	struct tls_decrypt_arg darg = { .zc = true, };

	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
	return decrypt_internal(sk, skb, NULL, sgout, &darg);
}

static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
			       unsigned int len)
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
				   u8 *control)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	int err;

	if (skb) {
		struct strp_msg *rxm = strp_msg(skb);
	if (!*control) {
		*control = tlm->control;
		if (!*control)
			return -EBADMSG;

		if (len < rxm->full_len) {
			rxm->offset += len;
			rxm->full_len -= len;
			return false;
		err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
			       sizeof(*control), control);
		if (*control != TLS_RECORD_TYPE_DATA) {
			if (err || msg->msg_flags & MSG_CTRUNC)
				return -EIO;
		}
		consume_skb(skb);
	} else if (*control != tlm->control) {
		return 0;
	}

	/* Finished with message */
	ctx->recv_pkt = NULL;
	__strp_unpause(&ctx->strp);

	return true;
	return 1;
}

/* This function traverses the rx_list in tls receive context to copy the
@@ -1646,31 +1642,23 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
static int process_rx_list(struct tls_sw_context_rx *ctx,
			   struct msghdr *msg,
			   u8 *control,
			   bool *cmsg,
			   size_t skip,
			   size_t len,
			   bool zc,
			   bool is_peek)
{
	struct sk_buff *skb = skb_peek(&ctx->rx_list);
	u8 ctrl = *control;
	u8 msgc = *cmsg;
	struct tls_msg *tlm;
	ssize_t copied = 0;

	/* Set the record type in 'control' if caller didn't pass it */
	if (!ctrl && skb) {
		tlm = tls_msg(skb);
		ctrl = tlm->control;
	}
	int err;

	while (skip && skb) {
		struct strp_msg *rxm = strp_msg(skb);
		tlm = tls_msg(skb);

		/* Cannot process a record of different type */
		if (ctrl != tlm->control)
			return 0;
		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			return err;

		if (skip < rxm->full_len)
			break;
@@ -1686,27 +1674,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,

		tlm = tls_msg(skb);

		/* Cannot process a record of different type */
		if (ctrl != tlm->control)
			return 0;

		/* Set record type if not already done. For a non-data record,
		 * do not proceed if record type could not be copied.
		 */
		if (!msgc) {
			int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					    sizeof(ctrl), &ctrl);
			msgc = true;
			if (ctrl != TLS_RECORD_TYPE_DATA) {
				if (cerr || msg->msg_flags & MSG_CTRUNC)
					return -EIO;

				*cmsg = msgc;
			}
		}
		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			return err;

		if (!zc || (rxm->full_len - skip) > len) {
			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
						    msg, chunk);
			if (err < 0)
				return err;
@@ -1743,7 +1716,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
		skb = next_skb;
	}

	*control = ctrl;
	return copied;
}

@@ -1758,19 +1730,19 @@ int tls_sw_recvmsg(struct sock *sk,
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct sk_psock *psock;
	int num_async, pending;
	unsigned char control = 0;
	ssize_t decrypted = 0;
	struct strp_msg *rxm;
	struct tls_msg *tlm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool cmsg = false;
	bool async = false;
	int target, err = 0;
	long timeo;
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool is_peek = flags & MSG_PEEK;
	bool bpf_strp_enabled;
	bool zc_capable;

	flags |= nonblock;

@@ -1782,8 +1754,7 @@ int tls_sw_recvmsg(struct sock *sk,
	bpf_strp_enabled = sk_psock_strp_enabled(psock);

	/* Process pending decrypted records. It must be non-zero-copy */
	err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
			      is_peek);
	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
	if (err < 0) {
		tls_err_abort(sk, err);
		goto end;
@@ -1797,15 +1768,12 @@ int tls_sw_recvmsg(struct sock *sk,
	len = len - copied;
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);

	zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
		     prot->version != TLS_1_3_VERSION;
	decrypted = 0;
	num_async = 0;
	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
		bool retain_skb = false;
		bool zc = false;
		int to_decrypt;
		int chunk = 0;
		bool async_capable;
		bool async = false;
		struct tls_decrypt_arg darg = {};
		int to_decrypt, chunk;

		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
		if (!skb) {
@@ -1827,29 +1795,24 @@ int tls_sw_recvmsg(struct sock *sk,

		to_decrypt = rxm->full_len - prot->overhead_size;

		if (to_decrypt <= len && !is_kvec && !is_peek &&
		    tlm->control == TLS_RECORD_TYPE_DATA &&
		    prot->version != TLS_1_3_VERSION &&
		    !bpf_strp_enabled)
			zc = true;
		if (zc_capable && to_decrypt <= len &&
		    tlm->control == TLS_RECORD_TYPE_DATA)
			darg.zc = true;

		/* Do not use async mode if record is non-data */
		if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
			async_capable = ctx->async_capable;
			darg.async = ctx->async_capable;
		else
			async_capable = false;
			darg.async = false;

		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
					 &chunk, &zc, async_capable);
		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
		if (err < 0 && err != -EINPROGRESS) {
			tls_err_abort(sk, -EBADMSG);
			goto recv_end;
		}

		if (err == -EINPROGRESS) {
		if (err == -EINPROGRESS)
			async = true;
			num_async++;
		}

		/* If the type of records being processed is not known yet,
		 * set it to record type just dequeued. If it is already known,
@@ -1858,92 +1821,79 @@ int tls_sw_recvmsg(struct sock *sk,
		 * is known just after record is dequeued from stream parser.
		 * For tls1.3, we disable async.
		 */

		if (!control)
			control = tlm->control;
		else if (control != tlm->control)
		err = tls_record_content_type(msg, tlm, &control);
		if (err <= 0)
			goto recv_end;

		if (!cmsg) {
			int cerr;
		ctx->recv_pkt = NULL;
		__strp_unpause(&ctx->strp);
		skb_queue_tail(&ctx->rx_list, skb);

			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					sizeof(control), &control);
			cmsg = true;
			if (control != TLS_RECORD_TYPE_DATA) {
				if (cerr || msg->msg_flags & MSG_CTRUNC) {
					err = -EIO;
					goto recv_end;
				}
			}
		if (async) {
			/* TLS 1.2-only, to_decrypt must be text length */
			chunk = min_t(int, to_decrypt, len);
leave_on_list:
			decrypted += chunk;
			len -= chunk;
			continue;
		}
		/* TLS 1.3 may have updated the length by more than overhead */
		chunk = rxm->full_len;

		if (async)
			goto pick_next_record;
		if (!darg.zc) {
			bool partially_consumed = chunk > len;

		if (!zc) {
			if (bpf_strp_enabled) {
				err = sk_psock_tls_strp_read(psock, skb);
				if (err != __SK_PASS) {
					rxm->offset = rxm->offset + rxm->full_len;
					rxm->full_len = 0;
					skb_unlink(skb, &ctx->rx_list);
					if (err == __SK_DROP)
						consume_skb(skb);
					ctx->recv_pkt = NULL;
					__strp_unpause(&ctx->strp);
					continue;
				}
			}

			if (rxm->full_len > len) {
				retain_skb = true;
			if (partially_consumed)
				chunk = len;
			} else {
				chunk = rxm->full_len;
			}

			err = skb_copy_datagram_msg(skb, rxm->offset,
						    msg, chunk);
			if (err < 0)
				goto recv_end;

			if (!is_peek) {
				rxm->offset = rxm->offset + chunk;
				rxm->full_len = rxm->full_len - chunk;
			if (is_peek)
				goto leave_on_list;

			if (partially_consumed) {
				rxm->offset += chunk;
				rxm->full_len -= chunk;
				goto leave_on_list;
			}
		}

pick_next_record:
		if (chunk > len)
			chunk = len;

		decrypted += chunk;
		len -= chunk;

		/* For async or peek case, queue the current skb */
		if (async || is_peek || retain_skb) {
			skb_queue_tail(&ctx->rx_list, skb);
			skb = NULL;
		}
		skb_unlink(skb, &ctx->rx_list);
		consume_skb(skb);

		if (tls_sw_advance_skb(sk, skb, chunk)) {
			/* Return full control message to
			 * userspace before trying to parse
			 * another message type
		/* Return full control message to userspace before trying
		 * to parse another message type
		 */
		msg->msg_flags |= MSG_EOR;
		if (control != TLS_RECORD_TYPE_DATA)
				goto recv_end;
		} else {
			break;
	}
	}

recv_end:
	if (num_async) {
	if (async) {
		int pending;

		/* Wait for all previously submitted records to be decrypted */
		spin_lock_bh(&ctx->decrypt_compl_lock);
		ctx->async_notify = true;
		reinit_completion(&ctx->async_wait.completion);
		pending = atomic_read(&ctx->decrypt_pending);
		spin_unlock_bh(&ctx->decrypt_compl_lock);
		if (pending) {
@@ -1955,21 +1905,14 @@ int tls_sw_recvmsg(struct sock *sk,
				decrypted = 0;
				goto end;
			}
		} else {
			reinit_completion(&ctx->async_wait.completion);
		}

		/* There can be no concurrent accesses, since we have no
		 * pending decrypt operations
		 */
		WRITE_ONCE(ctx->async_notify, false);

		/* Drain records from the rx_list & copy if required */
		if (is_peek || is_kvec)
			err = process_rx_list(ctx, msg, &control, &cmsg, copied,
			err = process_rx_list(ctx, msg, &control, copied,
					      decrypted, false, is_peek);
		else
			err = process_rx_list(ctx, msg, &control, &cmsg, 0,
			err = process_rx_list(ctx, msg, &control, 0,
					      decrypted, true, is_peek);
		if (err < 0) {
			tls_err_abort(sk, err);
@@ -2003,7 +1946,6 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	int err = 0;
	long timeo;
	int chunk;
	bool zc = false;

	lock_sock(sk);

@@ -2013,12 +1955,14 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	if (from_queue) {
		skb = __skb_dequeue(&ctx->rx_list);
	} else {
		struct tls_decrypt_arg darg = {};

		skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
				    &err);
		if (!skb)
			goto splice_read_end;

		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
		err = decrypt_skb_update(sk, skb, NULL, &darg);
		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto splice_read_end;