Commit 6bd116c8 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller
Browse files

tls: rx: return the decrypted skb via darg



Instead of using ctx->recv_pkt after decryption read the skb
from darg.skb. This moves the decision of what the "output skb"
is to the decrypt handlers. For now after decrypt handler returns
successfully ctx->recv_pkt is simply moved to darg.skb, but it
will change soon.

Note that tls_decrypt_sg() cannot clear the ctx->recv_pkt
because it gets called to re-encrypt (i.e. by the device offload).
So we need an awkward temporary if() in tls_rx_one_record().

Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 541cc48b
Loading
Loading
Loading
Loading
+39 −10
Original line number Diff line number Diff line
@@ -47,9 +47,13 @@
#include "tls.h"

struct tls_decrypt_arg {
	struct_group(inargs,
	bool zc;
	bool async;
	u8 tail;
	);

	struct sk_buff *skb;
};

struct tls_decrypt_ctx {
@@ -1412,6 +1416,7 @@ static int tls_setup_from_iter(struct iov_iter *from,
 * -------------------------------------------------------------------
 *    zc | Zero-copy decrypt allowed | Zero-copy performed
 * async | Async decrypt allowed     | Async crypto used / in progress
 *   skb |            *              | Output skb
 */

/* This function decrypts the input skb into either out_iov or in out_sg
@@ -1551,12 +1556,17 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
				data_len + prot->tail_size, aead_req, darg);
	if (err)
		goto exit_free_pages;

	darg->skb = tls_strp_msg(ctx);
	if (darg->async)
		return 0;

	if (prot->tail_size)
		darg->tail = dctx->tail;

exit_free_pages:
	/* Release the pages in case iov was mapped to pages */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));
@@ -1569,6 +1579,7 @@ static int
tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
		   struct tls_decrypt_arg *darg)
{
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	int err;

	if (tls_ctx->rx_conf != TLS_HW)
@@ -1580,6 +1591,8 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,

	darg->zc = false;
	darg->async = false;
	darg->skb = tls_strp_msg(ctx);
	ctx->recv_pkt = NULL;
	return 1;
}

@@ -1604,8 +1617,11 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
		return err;
	}
	if (darg->async)
	if (darg->async) {
		if (darg->skb == ctx->recv_pkt)
			ctx->recv_pkt = NULL;
		goto decrypt_next;
	}
	/* If opportunistic TLS 1.3 ZC failed retry without ZC */
	if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
		     darg->tail != TLS_RECORD_TYPE_DATA)) {
@@ -1616,12 +1632,17 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
		return tls_rx_one_record(sk, dest, darg);
	}

	if (darg->skb == ctx->recv_pkt)
		ctx->recv_pkt = NULL;

decrypt_done:
	pad = tls_padding_length(prot, ctx->recv_pkt, darg);
	if (pad < 0)
	pad = tls_padding_length(prot, darg->skb, darg);
	if (pad < 0) {
		consume_skb(darg->skb);
		return pad;
	}

	rxm = strp_msg(ctx->recv_pkt);
	rxm = strp_msg(darg->skb);
	rxm->full_len -= pad;
	rxm->offset += prot->prepend_size;
	rxm->full_len -= prot->overhead_size;
@@ -1663,6 +1684,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,

static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
{
	consume_skb(ctx->recv_pkt);
	ctx->recv_pkt = NULL;
	__strp_unpause(&ctx->strp);
}
@@ -1872,7 +1894,7 @@ int tls_sw_recvmsg(struct sock *sk,
		ctx->zc_capable;
	decrypted = 0;
	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
		struct tls_decrypt_arg darg = {};
		struct tls_decrypt_arg darg;
		int to_decrypt, chunk;

		err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
@@ -1889,9 +1911,10 @@ int tls_sw_recvmsg(struct sock *sk,
			goto recv_end;
		}

		skb = ctx->recv_pkt;
		rxm = strp_msg(skb);
		tlm = tls_msg(skb);
		memset(&darg.inargs, 0, sizeof(darg.inargs));

		rxm = strp_msg(ctx->recv_pkt);
		tlm = tls_msg(ctx->recv_pkt);

		to_decrypt = rxm->full_len - prot->overhead_size;

@@ -1911,6 +1934,10 @@ int tls_sw_recvmsg(struct sock *sk,
			goto recv_end;
		}

		skb = darg.skb;
		rxm = strp_msg(skb);
		tlm = tls_msg(skb);

		async |= darg.async;

		/* If the type of records being processed is not known yet,
@@ -2051,21 +2078,23 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	if (!skb_queue_empty(&ctx->rx_list)) {
		skb = __skb_dequeue(&ctx->rx_list);
	} else {
		struct tls_decrypt_arg darg = {};
		struct tls_decrypt_arg darg;

		err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
				      timeo);
		if (err <= 0)
			goto splice_read_end;

		memset(&darg.inargs, 0, sizeof(darg.inargs));

		err = tls_rx_one_record(sk, NULL, &darg);
		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto splice_read_end;
		}

		skb = ctx->recv_pkt;
		tls_rx_rec_done(ctx);
		skb = darg.skb;
	}

	rxm = strp_msg(skb);