Commit 8f1c3850 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'tls-rx-refactor-part-3'



Jakub Kicinski says:

====================
tls: rx: random refactoring part 3

TLS Rx refactoring. Part 3 of 3. This set is mostly around rx_list
and async processing. The last two patches are minor optimizations.
A couple of features to follow.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents f45ba67e a4ae58cd
Loading
Loading
Loading
Loading
+60 −71
Original line number Diff line number Diff line
@@ -188,18 +188,13 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
		tls_err_abort(skb->sk, err);
	} else {
		struct strp_msg *rxm = strp_msg(skb);
		int pad;

		pad = padding_length(prot, skb);
		if (pad < 0) {
			ctx->async_wait.err = pad;
			tls_err_abort(skb->sk, pad);
		} else {
			rxm->full_len -= pad;
		/* No TLS 1.3 support with async crypto */
		WARN_ON(prot->tail_size);

		rxm->offset += prot->prepend_size;
		rxm->full_len -= prot->overhead_size;
	}
	}

	/* After using skb->sk to propagate sk through crypto async callback
	 * we need to NULL it again.
@@ -232,7 +227,7 @@ static int tls_do_decryption(struct sock *sk,
			     char *iv_recv,
			     size_t data_len,
			     struct aead_request *aead_req,
			     bool async)
			     struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -245,7 +240,7 @@ static int tls_do_decryption(struct sock *sk,
			       data_len + prot->tag_size,
			       (u8 *)iv_recv);

	if (async) {
	if (darg->async) {
		/* Using skb->sk to push sk through to crypto async callback
		 * handler. This allows propagating errors up to the socket
		 * if needed. It _must_ be cleared in the async handler
@@ -265,14 +260,15 @@ static int tls_do_decryption(struct sock *sk,

	ret = crypto_aead_decrypt(aead_req);
	if (ret == -EINPROGRESS) {
		if (async)
			return ret;
		if (darg->async)
			return 0;

		ret = crypto_wait_req(ret, &ctx->async_wait);
	}
	darg->async = false;

	if (async)
		atomic_dec(&ctx->decrypt_pending);
	if (ret == -EBADMSG)
		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);

	return ret;
}
@@ -1456,7 +1452,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
	mem_size = mem_size + prot->aad_size;
	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
	mem_size = mem_size + MAX_IV_SIZE;

	/* Allocate a single block of memory which contains
	 * aead_req || sgin[] || sgout[] || aad || iv.
@@ -1486,6 +1482,11 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
	}

	/* Prepare IV */
	if (prot->version == TLS_1_3_VERSION ||
	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
		memcpy(iv + iv_offset, tls_ctx->rx.iv,
		       prot->iv_size + prot->salt_size);
	} else {
		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
				    iv + iv_offset + prot->salt_size,
				    prot->iv_size);
@@ -1493,13 +1494,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			kfree(mem);
			return err;
		}
	if (prot->version == TLS_1_3_VERSION ||
	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305)
		memcpy(iv + iv_offset, tls_ctx->rx.iv,
		       prot->iv_size + prot->salt_size);
	else
		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);

	}
	xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);

	/* Prepare AAD */
@@ -1542,9 +1538,9 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
				data_len, aead_req, darg->async);
	if (err == -EINPROGRESS)
		return err;
				data_len, aead_req, darg);
	if (darg->async)
		return 0;

	/* Release the pages in case iov was mapped to pages */
	for (; pages > 0; pages--)
@@ -1581,13 +1577,10 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
	}

	err = decrypt_internal(sk, skb, dest, NULL, darg);
	if (err < 0) {
		if (err == -EINPROGRESS)
			tls_advance_record_sn(sk, prot, &tls_ctx->rx);
		else if (err == -EBADMSG)
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
	if (err < 0)
		return err;
	}
	if (darg->async)
		goto decrypt_next;

decrypt_done:
	pad = padding_length(prot, skb);
@@ -1597,8 +1590,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
	rxm->full_len -= pad;
	rxm->offset += prot->prepend_size;
	rxm->full_len -= prot->overhead_size;
	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
	tlm->decrypted = 1;
decrypt_next:
	tls_advance_record_sn(sk, prot, &tls_ctx->rx);

	return 0;
}
@@ -1658,7 +1652,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			return err;
			goto out;

		if (skip < rxm->full_len)
			break;
@@ -1676,13 +1670,13 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			return err;
			goto out;

		if (!zc || (rxm->full_len - skip) > len) {
			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
						    msg, chunk);
			if (err < 0)
				return err;
				goto out;
		}

		len = len - chunk;
@@ -1709,14 +1703,16 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
		next_skb = skb_peek_next(skb, &ctx->rx_list);

		if (!is_peek) {
			skb_unlink(skb, &ctx->rx_list);
			__skb_unlink(skb, &ctx->rx_list);
			consume_skb(skb);
		}

		skb = next_skb;
	}
	err = 0;

	return copied;
out:
	return copied ? : err;
}

int tls_sw_recvmsg(struct sock *sk,
@@ -1750,12 +1746,15 @@ int tls_sw_recvmsg(struct sock *sk,
	lock_sock(sk);
	bpf_strp_enabled = sk_psock_strp_enabled(psock);

	/* If crypto failed the connection is broken */
	err = ctx->async_wait.err;
	if (err)
		goto end;

	/* Process pending decrypted records. It must be non-zero-copy */
	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
	if (err < 0) {
		tls_err_abort(sk, err);
	if (err < 0)
		goto end;
	}

	copied = err;
	if (len <= copied)
@@ -1775,14 +1774,10 @@ int tls_sw_recvmsg(struct sock *sk,
		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
		if (!skb) {
			if (psock) {
				int ret = sk_msg_recvmsg(sk, psock, msg, len,
				chunk = sk_msg_recvmsg(sk, psock, msg, len,
						       flags);

				if (ret > 0) {
					decrypted += ret;
					len -= ret;
					continue;
				}
				if (chunk > 0)
					goto leave_on_list;
			}
			goto recv_end;
		}
@@ -1803,13 +1798,12 @@ int tls_sw_recvmsg(struct sock *sk,
			darg.async = false;

		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
		if (err < 0 && err != -EINPROGRESS) {
		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto recv_end;
		}

		if (err == -EINPROGRESS)
			async = true;
		async |= darg.async;

		/* If the type of records being processed is not known yet,
		 * set it to record type just dequeued. If it is already known,
@@ -1824,7 +1818,7 @@ int tls_sw_recvmsg(struct sock *sk,

		ctx->recv_pkt = NULL;
		__strp_unpause(&ctx->strp);
		skb_queue_tail(&ctx->rx_list, skb);
		__skb_queue_tail(&ctx->rx_list, skb);

		if (async) {
			/* TLS 1.2-only, to_decrypt must be text length */
@@ -1845,7 +1839,7 @@ int tls_sw_recvmsg(struct sock *sk,
				if (err != __SK_PASS) {
					rxm->offset = rxm->offset + rxm->full_len;
					rxm->full_len = 0;
					skb_unlink(skb, &ctx->rx_list);
					__skb_unlink(skb, &ctx->rx_list);
					if (err == __SK_DROP)
						consume_skb(skb);
					continue;
@@ -1873,7 +1867,7 @@ int tls_sw_recvmsg(struct sock *sk,
		decrypted += chunk;
		len -= chunk;

		skb_unlink(skb, &ctx->rx_list);
		__skb_unlink(skb, &ctx->rx_list);
		consume_skb(skb);

		/* Return full control message to userspace before trying
@@ -1886,7 +1880,7 @@ int tls_sw_recvmsg(struct sock *sk,

recv_end:
	if (async) {
		int pending;
		int ret, pending;

		/* Wait for all previously submitted records to be decrypted */
		spin_lock_bh(&ctx->decrypt_compl_lock);
@@ -1894,11 +1888,10 @@ int tls_sw_recvmsg(struct sock *sk,
		pending = atomic_read(&ctx->decrypt_pending);
		spin_unlock_bh(&ctx->decrypt_compl_lock);
		if (pending) {
			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			if (err) {
				/* one of async decrypt failed */
				tls_err_abort(sk, err);
				copied = 0;
			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			if (ret) {
				if (err >= 0 || err == -EINPROGRESS)
					err = ret;
				decrypted = 0;
				goto end;
			}
@@ -1911,11 +1904,7 @@ int tls_sw_recvmsg(struct sock *sk,
		else
			err = process_rx_list(ctx, msg, &control, 0,
					      decrypted, true, is_peek);
		if (err < 0) {
			tls_err_abort(sk, err);
			copied = 0;
			goto end;
		}
		decrypted = max(err, 0);
	}

	copied += decrypted;
@@ -2173,7 +2162,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
	if (ctx->aead_recv) {
		kfree_skb(ctx->recv_pkt);
		ctx->recv_pkt = NULL;
		skb_queue_purge(&ctx->rx_list);
		__skb_queue_purge(&ctx->rx_list);
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		/* If tls_sw_strparser_arm() was not called (cleanup paths)