Commit 419ce133 authored by Paolo Abeni's avatar Paolo Abeni Committed by Jakub Kicinski
Browse files

tcp: allow again tcp_disconnect() when threads are waiting



As reported by Tom, .NET and applications build on top of it rely
on connect(AF_UNSPEC) to async cancel pending I/O operations on TCP
socket.

The blamed commit below caused a regression, as such cancellation
can now fail.

As suggested by Eric, this change addresses the problem explicitly
causing blocking I/O operation to terminate immediately (with an error)
when a concurrent disconnect() is executed.

Instead of tracking the number of threads blocked on a given socket,
track the number of disconnect() issued on such socket. If such counter
changes after a blocking operation releasing and re-acquiring the socket
lock, error out the current operation.

Fixes: 4faeee0c ("tcp: deny tcp_disconnect() when threads are waiting")
Reported-by: default avatarTom Deseyn <tdeseyn@redhat.com>
Closes: https://bugzilla.redhat.com/show_bug.cgi?id=1886305


Suggested-by: default avatarEric Dumazet <edumazet@google.com>
Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
Reviewed-by: default avatarEric Dumazet <edumazet@google.com>
Link: https://lore.kernel.org/r/f3b95e47e3dbed840960548aebaa8d954372db41.1697008693.git.pabeni@redhat.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent 242e3450
Loading
Loading
Loading
Loading
+29 −7
Original line number Original line Diff line number Diff line
@@ -911,7 +911,7 @@ static int csk_wait_memory(struct chtls_dev *cdev,
			   struct sock *sk, long *timeo_p)
			   struct sock *sk, long *timeo_p)
{
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int err = 0;
	int ret, err = 0;
	long current_timeo;
	long current_timeo;
	long vm_wait = 0;
	long vm_wait = 0;
	bool noblock;
	bool noblock;
@@ -942,10 +942,13 @@ static int csk_wait_memory(struct chtls_dev *cdev,


		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
		sk->sk_write_pending++;
		sk->sk_write_pending++;
		sk_wait_event(sk, &current_timeo, sk->sk_err ||
		ret = sk_wait_event(sk, &current_timeo, sk->sk_err ||
				    (sk->sk_shutdown & SEND_SHUTDOWN) ||
				    (sk->sk_shutdown & SEND_SHUTDOWN) ||
			      (csk_mem_free(cdev, sk) && !vm_wait), &wait);
				    (csk_mem_free(cdev, sk) && !vm_wait),
				    &wait);
		sk->sk_write_pending--;
		sk->sk_write_pending--;
		if (ret < 0)
			goto do_error;


		if (vm_wait) {
		if (vm_wait) {
			vm_wait -= current_timeo;
			vm_wait -= current_timeo;
@@ -1348,6 +1351,7 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
	int copied = 0;
	int copied = 0;
	int target;
	int target;
	long timeo;
	long timeo;
	int ret;


	buffers_freed = 0;
	buffers_freed = 0;


@@ -1423,7 +1427,11 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		if (copied >= target)
		if (copied >= target)
			break;
			break;
		chtls_cleanup_rbuf(sk, copied);
		chtls_cleanup_rbuf(sk, copied);
		sk_wait_data(sk, &timeo, NULL);
		ret = sk_wait_data(sk, &timeo, NULL);
		if (ret < 0) {
			copied = copied ? : ret;
			goto unlock;
		}
		continue;
		continue;
found_ok_skb:
found_ok_skb:
		if (!skb->len) {
		if (!skb->len) {
@@ -1518,6 +1526,8 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,


	if (buffers_freed)
	if (buffers_freed)
		chtls_cleanup_rbuf(sk, copied);
		chtls_cleanup_rbuf(sk, copied);

unlock:
	release_sock(sk);
	release_sock(sk);
	return copied;
	return copied;
}
}
@@ -1534,6 +1544,7 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
	int copied = 0;
	int copied = 0;
	size_t avail;          /* amount of available data in current skb */
	size_t avail;          /* amount of available data in current skb */
	long timeo;
	long timeo;
	int ret;


	lock_sock(sk);
	lock_sock(sk);
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
@@ -1585,7 +1596,12 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
			release_sock(sk);
			release_sock(sk);
			lock_sock(sk);
			lock_sock(sk);
		} else {
		} else {
			sk_wait_data(sk, &timeo, NULL);
			ret = sk_wait_data(sk, &timeo, NULL);
			if (ret < 0) {
				/* here 'copied' is 0 due to previous checks */
				copied = ret;
				break;
			}
		}
		}


		if (unlikely(peek_seq != tp->copied_seq)) {
		if (unlikely(peek_seq != tp->copied_seq)) {
@@ -1656,6 +1672,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
	int copied = 0;
	int copied = 0;
	long timeo;
	long timeo;
	int target;             /* Read at least this many bytes */
	int target;             /* Read at least this many bytes */
	int ret;


	buffers_freed = 0;
	buffers_freed = 0;


@@ -1747,7 +1764,11 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		if (copied >= target)
		if (copied >= target)
			break;
			break;
		chtls_cleanup_rbuf(sk, copied);
		chtls_cleanup_rbuf(sk, copied);
		sk_wait_data(sk, &timeo, NULL);
		ret = sk_wait_data(sk, &timeo, NULL);
		if (ret < 0) {
			copied = copied ? : ret;
			goto unlock;
		}
		continue;
		continue;


found_ok_skb:
found_ok_skb:
@@ -1816,6 +1837,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
	if (buffers_freed)
	if (buffers_freed)
		chtls_cleanup_rbuf(sk, copied);
		chtls_cleanup_rbuf(sk, copied);


unlock:
	release_sock(sk);
	release_sock(sk);
	return copied;
	return copied;
}
}
+4 −6
Original line number Original line Diff line number Diff line
@@ -336,7 +336,7 @@ struct sk_filter;
  *	@sk_cgrp_data: cgroup data for this cgroup
  *	@sk_cgrp_data: cgroup data for this cgroup
  *	@sk_memcg: this socket's memory cgroup association
  *	@sk_memcg: this socket's memory cgroup association
  *	@sk_write_pending: a write to stream socket waits to start
  *	@sk_write_pending: a write to stream socket waits to start
  *	@sk_wait_pending: number of threads blocked on this socket
  *	@sk_disconnects: number of disconnect operations performed on this sock
  *	@sk_state_change: callback to indicate change in the state of the sock
  *	@sk_state_change: callback to indicate change in the state of the sock
  *	@sk_data_ready: callback to indicate there is data to be processed
  *	@sk_data_ready: callback to indicate there is data to be processed
  *	@sk_write_space: callback to indicate there is bf sending space available
  *	@sk_write_space: callback to indicate there is bf sending space available
@@ -429,7 +429,7 @@ struct sock {
	unsigned int		sk_napi_id;
	unsigned int		sk_napi_id;
#endif
#endif
	int			sk_rcvbuf;
	int			sk_rcvbuf;
	int			sk_wait_pending;
	int			sk_disconnects;


	struct sk_filter __rcu	*sk_filter;
	struct sk_filter __rcu	*sk_filter;
	union {
	union {
@@ -1189,8 +1189,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
}
}


#define sk_wait_event(__sk, __timeo, __condition, __wait)		\
#define sk_wait_event(__sk, __timeo, __condition, __wait)		\
	({	int __rc;						\
	({	int __rc, __dis = __sk->sk_disconnects;			\
		__sk->sk_wait_pending++;				\
		release_sock(__sk);					\
		release_sock(__sk);					\
		__rc = __condition;					\
		__rc = __condition;					\
		if (!__rc) {						\
		if (!__rc) {						\
@@ -1200,8 +1199,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
		}							\
		}							\
		sched_annotate_sleep();					\
		sched_annotate_sleep();					\
		lock_sock(__sk);					\
		lock_sock(__sk);					\
		__sk->sk_wait_pending--;				\
		__rc = __dis == __sk->sk_disconnects ? __condition : -EPIPE; \
		__rc = __condition;					\
		__rc;							\
		__rc;							\
	})
	})


+7 −5
Original line number Original line Diff line number Diff line
@@ -117,7 +117,7 @@ EXPORT_SYMBOL(sk_stream_wait_close);
 */
 */
int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
{
{
	int err = 0;
	int ret, err = 0;
	long vm_wait = 0;
	long vm_wait = 0;
	long current_timeo = *timeo_p;
	long current_timeo = *timeo_p;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
@@ -142,11 +142,13 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)


		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
		sk->sk_write_pending++;
		sk->sk_write_pending++;
		sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
		ret = sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
				    (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
				    (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
						  (sk_stream_memory_free(sk) &&
				    (sk_stream_memory_free(sk) && !vm_wait),
						  !vm_wait), &wait);
				    &wait);
		sk->sk_write_pending--;
		sk->sk_write_pending--;
		if (ret < 0)
			goto do_error;


		if (vm_wait) {
		if (vm_wait) {
			vm_wait -= current_timeo;
			vm_wait -= current_timeo;
+8 −2
Original line number Original line Diff line number Diff line
@@ -597,7 +597,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)


	add_wait_queue(sk_sleep(sk), &wait);
	add_wait_queue(sk_sleep(sk), &wait);
	sk->sk_write_pending += writebias;
	sk->sk_write_pending += writebias;
	sk->sk_wait_pending++;


	/* Basic assumption: if someone sets sk->sk_err, he _must_
	/* Basic assumption: if someone sets sk->sk_err, he _must_
	 * change state of the socket from TCP_SYN_*.
	 * change state of the socket from TCP_SYN_*.
@@ -613,7 +612,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
	}
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	remove_wait_queue(sk_sleep(sk), &wait);
	sk->sk_write_pending -= writebias;
	sk->sk_write_pending -= writebias;
	sk->sk_wait_pending--;
	return timeo;
	return timeo;
}
}


@@ -642,6 +640,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
			return -EINVAL;
			return -EINVAL;


		if (uaddr->sa_family == AF_UNSPEC) {
		if (uaddr->sa_family == AF_UNSPEC) {
			sk->sk_disconnects++;
			err = sk->sk_prot->disconnect(sk, flags);
			err = sk->sk_prot->disconnect(sk, flags);
			sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
			sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
			goto out;
			goto out;
@@ -696,6 +695,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
		int writebias = (sk->sk_protocol == IPPROTO_TCP) &&
		int writebias = (sk->sk_protocol == IPPROTO_TCP) &&
				tcp_sk(sk)->fastopen_req &&
				tcp_sk(sk)->fastopen_req &&
				tcp_sk(sk)->fastopen_req->data ? 1 : 0;
				tcp_sk(sk)->fastopen_req->data ? 1 : 0;
		int dis = sk->sk_disconnects;


		/* Error code is set above */
		/* Error code is set above */
		if (!timeo || !inet_wait_for_connect(sk, timeo, writebias))
		if (!timeo || !inet_wait_for_connect(sk, timeo, writebias))
@@ -704,6 +704,11 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
		err = sock_intr_errno(timeo);
		err = sock_intr_errno(timeo);
		if (signal_pending(current))
		if (signal_pending(current))
			goto out;
			goto out;

		if (dis != sk->sk_disconnects) {
			err = -EPIPE;
			goto out;
		}
	}
	}


	/* Connection was closed by RST, timeout, ICMP error
	/* Connection was closed by RST, timeout, ICMP error
@@ -725,6 +730,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
sock_error:
sock_error:
	err = sock_error(sk) ? : -ECONNABORTED;
	err = sock_error(sk) ? : -ECONNABORTED;
	sock->state = SS_UNCONNECTED;
	sock->state = SS_UNCONNECTED;
	sk->sk_disconnects++;
	if (sk->sk_prot->disconnect(sk, flags))
	if (sk->sk_prot->disconnect(sk, flags))
		sock->state = SS_DISCONNECTING;
		sock->state = SS_DISCONNECTING;
	goto out;
	goto out;
+0 −1
Original line number Original line Diff line number Diff line
@@ -1145,7 +1145,6 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
	if (newsk) {
	if (newsk) {
		struct inet_connection_sock *newicsk = inet_csk(newsk);
		struct inet_connection_sock *newicsk = inet_csk(newsk);


		newsk->sk_wait_pending = 0;
		inet_sk_set_state(newsk, TCP_SYN_RECV);
		inet_sk_set_state(newsk, TCP_SYN_RECV);
		newicsk->icsk_bind_hash = NULL;
		newicsk->icsk_bind_hash = NULL;
		newicsk->icsk_bind2_hash = NULL;
		newicsk->icsk_bind2_hash = NULL;
Loading