Commit 36397a18 authored by Martin KaFai Lau's avatar Martin KaFai Lau
Browse files

Merge branch 'Add SO_REUSEPORT support for TC bpf_sk_assign'



Lorenz Bauer says:

====================
We want to replace iptables TPROXY with a BPF program at TC ingress.
To make this work in all cases we need to assign a SO_REUSEPORT socket
to an skb, which is currently prohibited. This series adds support for
such sockets to bpf_sk_assing.

I did some refactoring to cut down on the amount of duplicate code. The
key to this is to use INDIRECT_CALL in the reuseport helpers. To show
that this approach is not just beneficial to TC sk_assign I removed
duplicate code for bpf_sk_lookup as well.

Joint work with Daniel Borkmann.

Signed-off-by: default avatarLorenz Bauer <lmb@isovalent.com>
---
Changes in v6:
- Reject unhashed UDP sockets in bpf_sk_assign to avoid ref leak
- Link to v5: https://lore.kernel.org/r/20230613-so-reuseport-v5-0-f6686a0dbce0@isovalent.com

Changes in v5:
- Drop reuse_sk == sk check in inet[6]_steal_stock (Kuniyuki)
- Link to v4: https://lore.kernel.org/r/20230613-so-reuseport-v4-0-4ece76708bba@isovalent.com

Changes in v4:
- WARN_ON_ONCE if reuseport socket is refcounted (Kuniyuki)
- Use inet[6]_ehashfn_t to shorten function declarations (Kuniyuki)
- Shuffle documentation patch around (Kuniyuki)
- Update commit message to explain why IPv6 needs EXPORT_SYMBOL
- Link to v3: https://lore.kernel.org/r/20230613-so-reuseport-v3-0-907b4cbb7b99@isovalent.com

Changes in v3:
- Fix warning re udp_ehashfn and udp6_ehashfn (Simon)
- Return higher scoring connected UDP reuseport sockets (Kuniyuki)
- Fix ipv6 module builds
- Link to v2: https://lore.kernel.org/r/20230613-so-reuseport-v2-0-b7c69a342613@isovalent.com



Changes in v2:
- Correct commit abbrev length (Kuniyuki)
- Reduce duplication (Kuniyuki)
- Add checks on sk_state (Martin)
- Split exporting inet[6]_lookup_reuseport into separate patch (Eric)

---
Daniel Borkmann (1):
      selftests/bpf: Test that SO_REUSEPORT can be used with sk_assign helper
====================

Signed-off-by: default avatarMartin KaFai Lau <martin.lau@kernel.org>
parents 7b2b2012 22408d58
Loading
Loading
Loading
Loading
+76 −5
Original line number Diff line number Diff line
@@ -48,6 +48,22 @@ struct sock *__inet6_lookup_established(struct net *net,
					const u16 hnum, const int dif,
					const int sdif);

typedef u32 (inet6_ehashfn_t)(const struct net *net,
			       const struct in6_addr *laddr, const u16 lport,
			       const struct in6_addr *faddr, const __be16 fport);

inet6_ehashfn_t inet6_ehashfn;

INDIRECT_CALLABLE_DECLARE(inet6_ehashfn_t udp6_ehashfn);

struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
				    struct sk_buff *skb, int doff,
				    const struct in6_addr *saddr,
				    __be16 sport,
				    const struct in6_addr *daddr,
				    unsigned short hnum,
				    inet6_ehashfn_t *ehashfn);

struct sock *inet6_lookup_listener(struct net *net,
				   struct inet_hashinfo *hashinfo,
				   struct sk_buff *skb, int doff,
@@ -57,6 +73,15 @@ struct sock *inet6_lookup_listener(struct net *net,
				   const unsigned short hnum,
				   const int dif, const int sdif);

struct sock *inet6_lookup_run_sk_lookup(struct net *net,
					int protocol,
					struct sk_buff *skb, int doff,
					const struct in6_addr *saddr,
					const __be16 sport,
					const struct in6_addr *daddr,
					const u16 hnum, const int dif,
					inet6_ehashfn_t *ehashfn);

static inline struct sock *__inet6_lookup(struct net *net,
					  struct inet_hashinfo *hashinfo,
					  struct sk_buff *skb, int doff,
@@ -78,6 +103,46 @@ static inline struct sock *__inet6_lookup(struct net *net,
				     daddr, hnum, dif, sdif);
}

static inline
struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
			      const struct in6_addr *saddr, const __be16 sport,
			      const struct in6_addr *daddr, const __be16 dport,
			      bool *refcounted, inet6_ehashfn_t *ehashfn)
{
	struct sock *sk, *reuse_sk;
	bool prefetched;

	sk = skb_steal_sock(skb, refcounted, &prefetched);
	if (!sk)
		return NULL;

	if (!prefetched)
		return sk;

	if (sk->sk_protocol == IPPROTO_TCP) {
		if (sk->sk_state != TCP_LISTEN)
			return sk;
	} else if (sk->sk_protocol == IPPROTO_UDP) {
		if (sk->sk_state != TCP_CLOSE)
			return sk;
	} else {
		return sk;
	}

	reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
					  saddr, sport, daddr, ntohs(dport),
					  ehashfn);
	if (!reuse_sk)
		return sk;

	/* We've chosen a new reuseport sock which is never refcounted. This
	 * implies that sk also isn't refcounted.
	 */
	WARN_ON_ONCE(*refcounted);

	return reuse_sk;
}

static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
					      struct sk_buff *skb, int doff,
					      const __be16 sport,
@@ -85,14 +150,20 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
					      int iif, int sdif,
					      bool *refcounted)
{
	struct sock *sk = skb_steal_sock(skb, refcounted);

	struct net *net = dev_net(skb_dst(skb)->dev);
	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
	struct sock *sk;

	sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
			      refcounted, inet6_ehashfn);
	if (IS_ERR(sk))
		return NULL;
	if (sk)
		return sk;

	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
			      doff, &ipv6_hdr(skb)->saddr, sport,
			      &ipv6_hdr(skb)->daddr, ntohs(dport),
	return __inet6_lookup(net, hashinfo, skb,
			      doff, &ip6h->saddr, sport,
			      &ip6h->daddr, ntohs(dport),
			      iif, sdif, refcounted);
}

+68 −6
Original line number Diff line number Diff line
@@ -379,6 +379,27 @@ struct sock *__inet_lookup_established(struct net *net,
				       const __be32 daddr, const u16 hnum,
				       const int dif, const int sdif);

typedef u32 (inet_ehashfn_t)(const struct net *net,
			      const __be32 laddr, const __u16 lport,
			      const __be32 faddr, const __be16 fport);

inet_ehashfn_t inet_ehashfn;

INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);

struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
				   struct sk_buff *skb, int doff,
				   __be32 saddr, __be16 sport,
				   __be32 daddr, unsigned short hnum,
				   inet_ehashfn_t *ehashfn);

struct sock *inet_lookup_run_sk_lookup(struct net *net,
				       int protocol,
				       struct sk_buff *skb, int doff,
				       __be32 saddr, __be16 sport,
				       __be32 daddr, u16 hnum, const int dif,
				       inet_ehashfn_t *ehashfn);

static inline struct sock *
	inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
				const __be32 saddr, const __be16 sport,
@@ -428,6 +449,46 @@ static inline struct sock *inet_lookup(struct net *net,
	return sk;
}

static inline
struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
			     const __be32 saddr, const __be16 sport,
			     const __be32 daddr, const __be16 dport,
			     bool *refcounted, inet_ehashfn_t *ehashfn)
{
	struct sock *sk, *reuse_sk;
	bool prefetched;

	sk = skb_steal_sock(skb, refcounted, &prefetched);
	if (!sk)
		return NULL;

	if (!prefetched)
		return sk;

	if (sk->sk_protocol == IPPROTO_TCP) {
		if (sk->sk_state != TCP_LISTEN)
			return sk;
	} else if (sk->sk_protocol == IPPROTO_UDP) {
		if (sk->sk_state != TCP_CLOSE)
			return sk;
	} else {
		return sk;
	}

	reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
					 saddr, sport, daddr, ntohs(dport),
					 ehashfn);
	if (!reuse_sk)
		return sk;

	/* We've chosen a new reuseport sock which is never refcounted. This
	 * implies that sk also isn't refcounted.
	 */
	WARN_ON_ONCE(*refcounted);

	return reuse_sk;
}

static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
					     struct sk_buff *skb,
					     int doff,
@@ -436,22 +497,23 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
					     const int sdif,
					     bool *refcounted)
{
	struct sock *sk = skb_steal_sock(skb, refcounted);
	struct net *net = dev_net(skb_dst(skb)->dev);
	const struct iphdr *iph = ip_hdr(skb);
	struct sock *sk;

	sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
			     refcounted, inet_ehashfn);
	if (IS_ERR(sk))
		return NULL;
	if (sk)
		return sk;

	return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
	return __inet_lookup(net, hashinfo, skb,
			     doff, iph->saddr, sport,
			     iph->daddr, dport, inet_iif(skb), sdif,
			     refcounted);
}

u32 inet6_ehashfn(const struct net *net,
		  const struct in6_addr *laddr, const u16 lport,
		  const struct in6_addr *faddr, const __be16 fport);

static inline void sk_daddr_set(struct sock *sk, __be32 addr)
{
	sk->sk_daddr = addr; /* alias of inet_daddr */
+5 −2
Original line number Diff line number Diff line
@@ -2815,20 +2815,23 @@ sk_is_refcounted(struct sock *sk)
 * skb_steal_sock - steal a socket from an sk_buff
 * @skb: sk_buff to steal the socket from
 * @refcounted: is set to true if the socket is reference-counted
 * @prefetched: is set to true if the socket was assigned from bpf
 */
static inline struct sock *
skb_steal_sock(struct sk_buff *skb, bool *refcounted)
skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched)
{
	if (skb->sk) {
		struct sock *sk = skb->sk;

		*refcounted = true;
		if (skb_sk_is_prefetched(skb))
		*prefetched = skb_sk_is_prefetched(skb);
		if (*prefetched)
			*refcounted = sk_is_refcounted(sk);
		skb->destructor = NULL;
		skb->sk = NULL;
		return sk;
	}
	*prefetched = false;
	*refcounted = false;
	return NULL;
}
+0 −3
Original line number Diff line number Diff line
@@ -4198,9 +4198,6 @@ union bpf_attr {
 *		**-EOPNOTSUPP** if the operation is not supported, for example
 *		a call from outside of TC ingress.
 *
 *		**-ESOCKTNOSUPPORT** if the socket type is not supported
 *		(reuseport).
 *
 * long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
 *	Description
 *		Helper is overloaded depending on BPF program type. This
+2 −2
Original line number Diff line number Diff line
@@ -7351,8 +7351,8 @@ BPF_CALL_3(bpf_sk_assign, struct sk_buff *, skb, struct sock *, sk, u64, flags)
		return -EOPNOTSUPP;
	if (unlikely(dev_net(skb->dev) != sock_net(sk)))
		return -ENETUNREACH;
	if (unlikely(sk_fullsock(sk) && sk->sk_reuseport))
		return -ESOCKTNOSUPPORT;
	if (sk_unhashed(sk))
		return -EOPNOTSUPP;
	if (sk_is_refcounted(sk) &&
	    unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
		return -ENOENT;
Loading