Commit 51e0158a authored by Cong Wang's avatar Cong Wang Committed by Daniel Borkmann
Browse files

skmsg: Pass psock pointer to ->psock_update_sk_prot()



Using sk_psock() to retrieve psock pointer from sock requires
RCU read lock, but we already get psock pointer before calling
->psock_update_sk_prot() in both cases, so we can just pass it
without bothering sk_psock().

Fixes: 8a59f9d1 ("sock: Introduce sk->sk_prot->psock_update_sk_prot()")
Reported-by: default avatar <syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com>
Signed-off-by: default avatarCong Wang <cong.wang@bytedance.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Tested-by: default avatar <syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com>
Reviewed-by: default avatarJakub Sitnicki <jakub@cloudflare.com>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20210407032111.33398-1-xiyou.wangcong@gmail.com
parent cbaa683b
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -99,7 +99,8 @@ struct sk_psock {
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	void (*saved_data_ready)(struct sock *sk);
	int  (*psock_update_sk_prot)(struct sock *sk, bool restore);
	int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
				     bool restore);
	struct proto			*sk_proto;
	struct mutex			work_mutex;
	struct sk_psock_work_state	work_state;
@@ -405,7 +406,7 @@ static inline void sk_psock_restore_proto(struct sock *sk,
{
	sk->sk_prot->unhash = psock->saved_unhash;
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, true);
		psock->psock_update_sk_prot(sk, psock, true);
}

static inline void sk_psock_set_state(struct sk_psock *psock,
+4 −1
Original line number Diff line number Diff line
@@ -1114,6 +1114,7 @@ struct inet_hashinfo;
struct raw_hashinfo;
struct smc_hashinfo;
struct module;
struct sk_psock;

/*
 * caches using SLAB_TYPESAFE_BY_RCU should let .next pointer from nulls nodes
@@ -1185,7 +1186,9 @@ struct proto {
	void			(*rehash)(struct sock *sk);
	int			(*get_port)(struct sock *sk, unsigned short snum);
#ifdef CONFIG_BPF_SYSCALL
	int			(*psock_update_sk_prot)(struct sock *sk, bool restore);
	int			(*psock_update_sk_prot)(struct sock *sk,
							struct sk_psock *psock,
							bool restore);
#endif

	/* Keeping track of sockets in use */
+1 −1
Original line number Diff line number Diff line
@@ -2215,7 +2215,7 @@ struct sk_psock;

#ifdef CONFIG_BPF_SYSCALL
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int tcp_bpf_update_proto(struct sock *sk, bool restore);
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#endif /* CONFIG_BPF_SYSCALL */

+1 −1
Original line number Diff line number Diff line
@@ -543,7 +543,7 @@ static inline void udp_post_segment_fix_csum(struct sk_buff *skb)
#ifdef CONFIG_BPF_SYSCALL
struct sk_psock;
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int udp_bpf_update_proto(struct sock *sk, bool restore);
int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
#endif

#endif	/* _UDP_H */
+1 −1
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
	if (!sk->sk_prot->psock_update_sk_prot)
		return -EINVAL;
	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
	return sk->sk_prot->psock_update_sk_prot(sk, false);
	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
Loading