Commit 45e5be84 authored by David Howells's avatar David Howells Committed by Jakub Kicinski
Browse files

tls/sw: Convert tls_sw_sendpage() to use MSG_SPLICE_PAGES



Convert tls_sw_sendpage() and tls_sw_sendpage_locked() to use sendmsg()
with MSG_SPLICE_PAGES rather than directly splicing in the pages itself.

[!] Note that tls_sw_sendpage_locked() appears to have the wrong locking
    upstream.  I think the caller will only hold the socket lock, but it
    should hold tls_ctx->tx_lock too.

This allows ->sendpage() to be replaced by something that can handle
multiple multipage folios in a single transaction.

Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
Reviewed-by: default avatarJakub Kicinski <kuba@kernel.org>
cc: Chuck Lever <chuck.lever@oracle.com>
cc: Boris Pismenny <borisp@nvidia.com>
cc: John Fastabend <john.fastabend@gmail.com>
cc: Jens Axboe <axboe@kernel.dk>
cc: Matthew Wilcox <willy@infradead.org>
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent fe1e81d4
Loading
Loading
Loading
Loading
+35 −138
Original line number Diff line number Diff line
@@ -960,7 +960,8 @@ static int tls_sw_sendmsg_splice(struct sock *sk, struct msghdr *msg,
	return 0;
}

int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
				 size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -983,15 +984,6 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
	int ret = 0;
	int pending;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
			       MSG_CMSG_COMPAT | MSG_SPLICE_PAGES))
		return -EOPNOTSUPP;

	ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
	if (ret)
		return ret;
	lock_sock(sk);

	if (unlikely(msg->msg_controllen)) {
		ret = tls_process_cmsg(sk, msg, &record_type);
		if (ret) {
@@ -1192,10 +1184,27 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);
	return copied > 0 ? copied : ret;
}

int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	int ret;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
			       MSG_CMSG_COMPAT | MSG_SPLICE_PAGES |
			       MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
		return -EOPNOTSUPP;

	ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
	if (ret)
		return ret;
	lock_sock(sk);
	ret = tls_sw_sendmsg_locked(sk, msg, size);
	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return copied > 0 ? copied : ret;
	return ret;
}

/*
@@ -1272,151 +1281,39 @@ void tls_sw_splice_eof(struct socket *sock)
	mutex_unlock(&tls_ctx->tx_lock);
}

static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
			      int offset, size_t size, int flags)
{
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	struct sk_msg *msg_pl;
	struct tls_rec *rec;
	int num_async = 0;
	ssize_t copied = 0;
	bool full_record;
	int record_room;
	int ret = 0;
	bool eor;

	eor = !(flags & MSG_SENDPAGE_NOTLAST);
	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	/* Call the sk_stream functions to manage the sndbuf mem. */
	while (size > 0) {
		size_t copy, required_size;

		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto sendpage_end;
		}

		if (ctx->open_rec)
			rec = ctx->open_rec;
		else
			rec = ctx->open_rec = tls_get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto sendpage_end;
		}

		msg_pl = &rec->msg_plaintext;

		full_record = false;
		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
		copy = size;
		if (copy >= record_room) {
			copy = record_room;
			full_record = true;
		}

		required_size = msg_pl->sg.size + copy + prot->overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_payload:
		ret = tls_alloc_encrypted_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			copy -= required_size - msg_pl->sg.size;
			full_record = true;
		}

		sk_msg_page_add(msg_pl, page, copy, offset);
		sk_mem_charge(sk, copy);

		offset += copy;
		size -= copy;
		copied += copy;

		tls_ctx->pending_open_record_frags = true;
		if (full_record || eor || sk_msg_full(msg_pl)) {
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied, flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ret != -EAGAIN) {
					if (ret == -ENOSPC)
						ret = 0;
					goto sendpage_end;
				}
			}
		}
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
			if (ctx->open_rec)
				tls_trim_both_msgs(sk, msg_pl->sg.size);
			goto sendpage_end;
		}

		if (ctx->open_rec)
			goto alloc_payload;
	}

	if (num_async) {
		/* Transmit if any encryptions have completed */
		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
			cancel_delayed_work(&ctx->tx_work.work);
			tls_tx_records(sk, flags);
		}
	}
sendpage_end:
	ret = sk_stream_error(sk, flags, ret);
	return copied > 0 ? copied : ret;
}

int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
			   int offset, size_t size, int flags)
{
	struct bio_vec bvec;
	struct msghdr msg = { .msg_flags = flags | MSG_SPLICE_PAGES, };

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
		      MSG_NO_SHARED_FRAGS))
		return -EOPNOTSUPP;
	if (flags & MSG_SENDPAGE_NOTLAST)
		msg.msg_flags |= MSG_MORE;

	return tls_sw_do_sendpage(sk, page, offset, size, flags);
	bvec_set_page(&bvec, page, size, offset);
	iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);
	return tls_sw_sendmsg_locked(sk, &msg, size);
}

int tls_sw_sendpage(struct sock *sk, struct page *page,
		    int offset, size_t size, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	int ret;
	struct bio_vec bvec;
	struct msghdr msg = { .msg_flags = flags | MSG_SPLICE_PAGES, };

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
		return -EOPNOTSUPP;
	if (flags & MSG_SENDPAGE_NOTLAST)
		msg.msg_flags |= MSG_MORE;

	ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
	if (ret)
		return ret;
	lock_sock(sk);
	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return ret;
	bvec_set_page(&bvec, page, size, offset);
	iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);
	return tls_sw_sendmsg(sk, &msg, size);
}

static int