Commit 29003875 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov
Browse files

bpf: Change bpf_setsockopt(SOL_SOCKET) to reuse sk_setsockopt()



After the prep work in the previous patches,
this patch removes most of the dup code from bpf_setsockopt(SOL_SOCKET)
and reuses them from sk_setsockopt().

The sock ptr test is added to the SO_RCVLOWAT because
the sk->sk_socket could be NULL in some of the bpf hooks.

The existing optname white-list is refactored into a new
function sol_socket_setsockopt().

Reviewed-by: default avatarStanislav Fomichev <sdf@google.com>
Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/r/20220817061804.4178920-1-kafai@fb.com


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent ebf9e8e6
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -1828,6 +1828,8 @@ void sock_pfree(struct sk_buff *skb);
#define sock_edemux sock_efree
#endif

int sk_setsockopt(struct sock *sk, int level, int optname,
		  sockptr_t optval, unsigned int optlen);
int sock_setsockopt(struct socket *sock, int level, int op,
		    sockptr_t optval, unsigned int optlen);

+29 −95
Original line number Diff line number Diff line
@@ -5013,109 +5013,43 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
	.arg1_type      = ARG_PTR_TO_CTX,
};

static int __bpf_setsockopt(struct sock *sk, int level, int optname,
static int sol_socket_setsockopt(struct sock *sk, int optname,
				 char *optval, int optlen)
{
	char devname[IFNAMSIZ];
	int val, valbool;
	struct net *net;
	int ifindex;
	int ret = 0;

	if (!sk_fullsock(sk))
		return -EINVAL;

	if (level == SOL_SOCKET) {
		if (optlen != sizeof(int) && optname != SO_BINDTODEVICE)
			return -EINVAL;
		val = *((int *)optval);
		valbool = val ? 1 : 0;

		/* Only some socketops are supported */
	switch (optname) {
		case SO_RCVBUF:
			val = min_t(u32, val, sysctl_rmem_max);
			val = min_t(int, val, INT_MAX / 2);
			sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
			WRITE_ONCE(sk->sk_rcvbuf,
				   max_t(int, val * 2, SOCK_MIN_RCVBUF));
			break;
	case SO_SNDBUF:
			val = min_t(u32, val, sysctl_wmem_max);
			val = min_t(int, val, INT_MAX / 2);
			sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
			WRITE_ONCE(sk->sk_sndbuf,
				   max_t(int, val * 2, SOCK_MIN_SNDBUF));
			break;
		case SO_MAX_PACING_RATE: /* 32bit version */
			if (val != ~0U)
				cmpxchg(&sk->sk_pacing_status,
					SK_PACING_NONE,
					SK_PACING_NEEDED);
			sk->sk_max_pacing_rate = (val == ~0U) ?
						 ~0UL : (unsigned int)val;
			sk->sk_pacing_rate = min(sk->sk_pacing_rate,
						 sk->sk_max_pacing_rate);
			break;
	case SO_RCVBUF:
	case SO_KEEPALIVE:
	case SO_PRIORITY:
			sk->sk_priority = val;
			break;
	case SO_REUSEPORT:
	case SO_RCVLOWAT:
			if (val < 0)
				val = INT_MAX;
			if (sk->sk_socket && sk->sk_socket->ops->set_rcvlowat)
				ret = sk->sk_socket->ops->set_rcvlowat(sk, val);
			else
				WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
			break;
	case SO_MARK:
			if (sk->sk_mark != val) {
				sk->sk_mark = val;
				sk_dst_reset(sk);
			}
			break;
		case SO_BINDTODEVICE:
			optlen = min_t(long, optlen, IFNAMSIZ - 1);
			strncpy(devname, optval, optlen);
			devname[optlen] = 0;

			ifindex = 0;
			if (devname[0] != '\0') {
				struct net_device *dev;

				ret = -ENODEV;

				net = sock_net(sk);
				dev = dev_get_by_name(net, devname);
				if (!dev)
					break;
				ifindex = dev->ifindex;
				dev_put(dev);
			}
			fallthrough;
	case SO_MAX_PACING_RATE:
	case SO_BINDTOIFINDEX:
			if (optname == SO_BINDTOIFINDEX)
				ifindex = val;
			ret = sock_bindtoindex(sk, ifindex, false);
			break;
		case SO_KEEPALIVE:
			if (sk->sk_prot->keepalive)
				sk->sk_prot->keepalive(sk, valbool);
			sock_valbool_flag(sk, SOCK_KEEPOPEN, valbool);
			break;
		case SO_REUSEPORT:
			sk->sk_reuseport = valbool;
			break;
	case SO_TXREHASH:
			if (val < -1 || val > 1) {
				ret = -EINVAL;
		if (optlen != sizeof(int))
			return -EINVAL;
		break;
			}
			sk->sk_txrehash = (u8)val;
	case SO_BINDTODEVICE:
		break;
	default:
			ret = -EINVAL;
		return -EINVAL;
	}

	return sk_setsockopt(sk, SOL_SOCKET, optname,
			     KERNEL_SOCKPTR(optval), optlen);
}

static int __bpf_setsockopt(struct sock *sk, int level, int optname,
			    char *optval, int optlen)
{
	int val, ret = 0;

	if (!sk_fullsock(sk))
		return -EINVAL;

	if (level == SOL_SOCKET) {
		return sol_socket_setsockopt(sk, optname, optval, optlen);
	} else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) {
		if (optlen != sizeof(int) || sk->sk_family != AF_INET)
			return -EINVAL;
+3 −3
Original line number Diff line number Diff line
@@ -1077,7 +1077,7 @@ EXPORT_SYMBOL(sockopt_capable);
 *	at the socket level. Everything here is generic.
 */

static int sk_setsockopt(struct sock *sk, int level, int optname,
int sk_setsockopt(struct sock *sk, int level, int optname,
		  sockptr_t optval, unsigned int optlen)
{
	struct so_timestamping timestamping;
@@ -1264,7 +1264,7 @@ static int sk_setsockopt(struct sock *sk, int level, int optname,
	case SO_RCVLOWAT:
		if (val < 0)
			val = INT_MAX;
		if (sock->ops->set_rcvlowat)
		if (sock && sock->ops->set_rcvlowat)
			ret = sock->ops->set_rcvlowat(sk, val);
		else
			WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);