Commit ddb8701d authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'splice-net-handle-msg_splice_pages-in-af_kcm'

David Howells says:

====================
splice, net: Handle MSG_SPLICE_PAGES in AF_KCM

Here are patches to make AF_KCM handle the MSG_SPLICE_PAGES internal
sendmsg flag.  MSG_SPLICE_PAGES is an internal hint that tells the protocol
that it should splice the pages supplied if it can.  Its sendpage
implementation is then turned into a wrapper around that.

Does anyone actually use AF_KCM?  Upstream it has some issues.  It doesn't
seem able to handle a "message" longer than 113920 bytes without jamming
and doesn't handle the client termination once it is jammed.

Link: https://git.kernel.org/pub/scm/linux/kernel/git/netdev/net-next.git/commit/?id=51c78a4d532efe9543a4df019ff405f05c6157f6 # part 1
Link: https://lore.kernel.org/r/20230524144923.3623536-1-dhowells@redhat.com/ # v1
====================

Link: https://lore.kernel.org/r/20230531110423.643196-1-dhowells@redhat.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 28cfea98 5bb3a5cb
Loading
Loading
Loading
Loading
+58 −160
Original line number Diff line number Diff line
@@ -761,149 +761,6 @@ static void kcm_push(struct kcm_sock *kcm)
		kcm_write_msgs(kcm);
}

/*
 * Splice a single page of data into the KCM socket's transmit path.
 *
 * Appends @size bytes at @offset within @page to the message currently
 * being built (kcm->seq_skb), or starts a new message if none is open.
 * If MSG_MORE (or MSG_SENDPAGE_NOTLAST) is set the message is kept open
 * for further data; otherwise it is queued on the send buffer and
 * transmission is attempted.
 *
 * Returns @size on success or a negative errno on failure.
 */
static ssize_t kcm_sendpage(struct socket *sock, struct page *page,
			    int offset, size_t size, int flags)

{
	struct sock *sk = sock->sk;
	struct kcm_sock *kcm = kcm_sk(sk);
	struct sk_buff *skb = NULL, *head = NULL;
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	bool eor;
	int err = 0;
	int i;

	/* NOTLAST from splice implies more data will follow. */
	if (flags & MSG_SENDPAGE_NOTLAST)
		flags |= MSG_MORE;

	/* No MSG_EOR from splice, only look at MSG_MORE */
	eor = !(flags & MSG_MORE);

	lock_sock(sk);

	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	err = -EPIPE;
	if (sk->sk_err)
		goto out_error;

	if (kcm->seq_skb) {
		/* Previously opened message */
		head = kcm->seq_skb;
		skb = kcm_tx_msg(head)->last_skb;
		i = skb_shinfo(skb)->nr_frags;

		/* Fast path: extend the last fragment if the new data is
		 * contiguous with it in the same page.
		 */
		if (skb_can_coalesce(skb, i, page, offset)) {
			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size);
			skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;
			goto coalesced;
		}

		/* Last skb is out of fragment slots: chain a fresh zero-data
		 * skb onto the message (via frag_list for the head, via
		 * skb->next otherwise) and continue filling that.
		 */
		if (i >= MAX_SKB_FRAGS) {
			struct sk_buff *tskb;

			tskb = alloc_skb(0, sk->sk_allocation);
			while (!tskb) {
				/* Push queued data and wait for memory, then
				 * retry the allocation.
				 */
				kcm_push(kcm);
				err = sk_stream_wait_memory(sk, &timeo);
				if (err)
					goto out_error;
			}

			if (head == skb)
				skb_shinfo(head)->frag_list = tskb;
			else
				skb->next = tskb;

			skb = tskb;
			skb->ip_summed = CHECKSUM_UNNECESSARY;
			i = 0;
		}
	} else {
		/* Call the sk_stream functions to manage the sndbuf mem. */
		if (!sk_stream_memory_free(sk)) {
			kcm_push(kcm);
			set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
			err = sk_stream_wait_memory(sk, &timeo);
			if (err)
				goto out_error;
		}

		/* Start a new message with a zero-data head skb; the page is
		 * attached below as a fragment.
		 */
		head = alloc_skb(0, sk->sk_allocation);
		while (!head) {
			kcm_push(kcm);
			err = sk_stream_wait_memory(sk, &timeo);
			if (err)
				goto out_error;
		}

		skb = head;
		i = 0;
	}

	/* Take a page ref; released when the skb is freed. */
	get_page(page);
	skb_fill_page_desc_noacc(skb, i, page, offset, size);
	skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;

coalesced:
	/* Account the new bytes on the skb and against the socket's
	 * send-buffer memory.
	 */
	skb->len += size;
	skb->data_len += size;
	skb->truesize += size;
	sk->sk_wmem_queued += size;
	sk_mem_charge(sk, size);

	/* If we filled a chained skb, the head's totals must also grow. */
	if (head != skb) {
		head->len += size;
		head->data_len += size;
		head->truesize += size;
	}

	if (eor) {
		bool not_busy = skb_queue_empty(&sk->sk_write_queue);

		/* Message complete, queue it on send buffer */
		__skb_queue_tail(&sk->sk_write_queue, head);
		kcm->seq_skb = NULL;
		KCM_STATS_INCR(kcm->stats.tx_msgs);

		if (flags & MSG_BATCH) {
			kcm->tx_wait_more = true;
		} else if (kcm->tx_wait_more || not_busy) {
			err = kcm_write_msgs(kcm);
			if (err < 0) {
				/* We got a hard error in write_msgs but have
				 * already queued this message. Report an error
				 * in the socket, but don't affect return value
				 * from sendmsg
				 */
				pr_warn("KCM: Hard failure on kcm_write_msgs\n");
				report_csk_error(&kcm->sk, -err);
			}
		}
	} else {
		/* Message not complete, save state */
		kcm->seq_skb = head;
		kcm_tx_msg(head)->last_skb = skb;
	}

	KCM_STATS_ADD(kcm->stats.tx_bytes, size);

	release_sock(sk);
	return size;

out_error:
	kcm_push(kcm);

	err = sk_stream_error(sk, flags, err);

	/* make sure we wake any epoll edge trigger waiter */
	if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
		sk->sk_write_space(sk);

	release_sock(sk);
	return err;
}

static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
	struct sock *sk = sock->sk;
@@ -989,9 +846,29 @@ static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
			merge = false;
		}

		if (msg->msg_flags & MSG_SPLICE_PAGES) {
			copy = msg_data_left(msg);
			if (!sk_wmem_schedule(sk, copy))
				goto wait_for_memory;

			err = skb_splice_from_iter(skb, &msg->msg_iter, copy,
						   sk->sk_allocation);
			if (err < 0) {
				if (err == -EMSGSIZE)
					goto wait_for_memory;
				goto out_error;
			}

			copy = err;
			skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;
			sk_wmem_queued_add(sk, copy);
			sk_mem_charge(sk, copy);

			if (head != skb)
				head->truesize += copy;
		} else {
			copy = min_t(int, msg_data_left(msg),
				     pfrag->size - pfrag->offset);

			if (!sk_wmem_schedule(sk, copy))
				goto wait_for_memory;

@@ -1004,7 +881,8 @@ static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)

			/* Update the skb. */
			if (merge) {
			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
				skb_frag_size_add(
					&skb_shinfo(skb)->frags[i - 1], copy);
			} else {
				skb_fill_page_desc(skb, i, pfrag->page,
						   pfrag->offset, copy);
@@ -1012,6 +890,8 @@ static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
			}

			pfrag->offset += copy;
		}

		copied += copy;
		if (head != skb) {
			head->len += copy;
@@ -1088,6 +968,24 @@ static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
	return err;
}

/*
 * sendpage wrapper: hand the page to kcm_sendmsg() as a single-element
 * bio_vec iterator with MSG_SPLICE_PAGES set, so the sendmsg path can
 * splice the page directly instead of copying it.
 *
 * Returns the number of bytes sent or a negative errno.
 */
static ssize_t kcm_sendpage(struct socket *sock, struct page *page,
			    int offset, size_t size, int flags)
{
	struct msghdr msg = {};
	struct bio_vec bvec;

	/* KCM has no out-of-band data. */
	if (flags & MSG_OOB)
		return -EOPNOTSUPP;

	msg.msg_flags = flags | MSG_SPLICE_PAGES;
	if (flags & MSG_SENDPAGE_NOTLAST)
		msg.msg_flags |= MSG_MORE;

	bvec_set_page(&bvec, page, size, offset);
	iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);
	return kcm_sendmsg(sock, &msg, size);
}

static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
		       size_t len, int flags)
{