Commit 6fd815bb authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'wireguard-fixes'



Jason A. Donenfeld says:

====================
wireguard fixes for 5.13-rc5

Here are bug fixes to WireGuard for 5.13-rc5:

1-2,6) These are small, trivial tweaks to our test harness.

3) Linus thinks -O3 is still dangerous to enable. The code gen wasn't so
   much different with -O2 either.

4) We were accidentally calling synchronize_rcu instead of
   synchronize_net while holding the rtnl_lock, resulting in some rather
   large stalls that hit production machines.

5) Peer allocation was wasting literally hundreds of megabytes on real
   world deployments, due to oddly sized large objects not fitting
   nicely into a kmalloc slab.

7-9) We move from an insanely expensive O(n) algorithm to a fast O(1)
     algorithm, and cleanup a massive memory leak in the process, in
     which allowed ips churn would leave danging nodes hanging around
     without cleanup until the interface was removed. The O(1) algorithm
     eliminates packet stalls and high latency issues, in addition to
     bringing operations that took as much as 10 minutes down to less
     than a second.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 579028de bf7b042d
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
ccflags-y := -O3
ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
wireguard-y := main.o
wireguard-y += noise.o
+99 −90
Original line number Diff line number Diff line
@@ -6,6 +6,8 @@
#include "allowedips.h"
#include "peer.h"

static struct kmem_cache *node_cache;

static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
	if (bits == 32) {
@@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
	node->bitlen = bits;
	memcpy(node->bits, src, bits / 8U);
}
#define CHOOSE_NODE(parent, key) \
	parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]

static inline u8 choose(struct allowedips_node *node, const u8 *key)
{
	return (key[node->bit_at_a] >> node->bit_at_b) & 1;
}

static void push_rcu(struct allowedips_node **stack,
		     struct allowedips_node __rcu *p, unsigned int *len)
@@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack,
	}
}

static void node_free_rcu(struct rcu_head *rcu)
{
	kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
}

static void root_free_rcu(struct rcu_head *rcu)
{
	struct allowedips_node *node, *stack[128] = {
@@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu)
	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		kfree(node);
		kmem_cache_free(node_cache, node);
	}
}

@@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
	}
}

static void walk_remove_by_peer(struct allowedips_node __rcu **top,
				struct wg_peer *peer, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
#define PUSH(p) ({                                                             \
		WARN_ON(IS_ENABLED(DEBUG) && len >= 128);                      \
		stack[len++] = p;                                              \
	})

	struct allowedips_node __rcu **stack[128], **nptr;
	struct allowedips_node *node, *prev;
	unsigned int len;

	if (unlikely(!peer || !REF(*top)))
		return;

	for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
		nptr = stack[len - 1];
		node = DEREF(nptr);
		if (!node) {
			--len;
			continue;
		}
		if (!prev || REF(prev->bit[0]) == node ||
		    REF(prev->bit[1]) == node) {
			if (REF(node->bit[0]))
				PUSH(&node->bit[0]);
			else if (REF(node->bit[1]))
				PUSH(&node->bit[1]);
		} else if (REF(node->bit[0]) == prev) {
			if (REF(node->bit[1]))
				PUSH(&node->bit[1]);
		} else {
			if (rcu_dereference_protected(node->peer,
				lockdep_is_held(lock)) == peer) {
				RCU_INIT_POINTER(node->peer, NULL);
				list_del_init(&node->peer_list);
				if (!node->bit[0] || !node->bit[1]) {
					rcu_assign_pointer(*nptr, DEREF(
					       &node->bit[!REF(node->bit[0])]));
					kfree_rcu(node, rcu);
					node = DEREF(nptr);
				}
			}
			--len;
		}
	}

#undef REF
#undef DEREF
#undef PUSH
}

static unsigned int fls128(u64 a, u64 b)
{
	return a ? fls64(a) + 64U : fls64(b);
@@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
			found = node;
		if (node->cidr == bits)
			break;
		node = rcu_dereference_bh(CHOOSE_NODE(node, key));
		node = rcu_dereference_bh(node->bit[choose(node, key)]);
	}
	return found;
}
@@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
			   u8 cidr, u8 bits, struct allowedips_node **rnode,
			   struct mutex *lock)
{
	struct allowedips_node *node = rcu_dereference_protected(trie,
						lockdep_is_held(lock));
	struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
	struct allowedips_node *parent = NULL;
	bool exact = false;

@@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
			exact = true;
			break;
		}
		node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
						 lockdep_is_held(lock));
		node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
	}
	*rnode = parent;
	return exact;
}

static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
{
	node->parent_bit_packed = (unsigned long)parent | bit;
	rcu_assign_pointer(*parent, node);
}

static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
{
	u8 bit = choose(parent, node->bits);
	connect_node(&parent->bit[bit], bit, node);
}

static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
@@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
		return -EINVAL;

	if (!rcu_access_pointer(*trie)) {
		node = kzalloc(sizeof(*node), GFP_KERNEL);
		node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
		if (unlikely(!node))
			return -ENOMEM;
		RCU_INIT_POINTER(node->peer, peer);
		list_add_tail(&node->peer_list, &peer->allowedips_list);
		copy_and_assign_cidr(node, key, cidr, bits);
		rcu_assign_pointer(*trie, node);
		connect_node(trie, 2, node);
		return 0;
	}
	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
@@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
		return 0;
	}

	newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
	newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
	if (unlikely(!newnode))
		return -ENOMEM;
	RCU_INIT_POINTER(newnode->peer, peer);
@@ -243,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	if (!node) {
		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
	} else {
		down = rcu_dereference_protected(CHOOSE_NODE(node, key),
						 lockdep_is_held(lock));
		const u8 bit = choose(node, key);
		down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
		if (!down) {
			rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
			connect_node(&node->bit[bit], bit, newnode);
			return 0;
		}
	}
@@ -254,30 +220,29 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	parent = node;

	if (newnode->cidr == cidr) {
		rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
		choose_and_connect_node(newnode, down);
		if (!parent)
			rcu_assign_pointer(*trie, newnode);
			connect_node(trie, 2, newnode);
		else
			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
					   newnode);
	} else {
		node = kzalloc(sizeof(*node), GFP_KERNEL);
			choose_and_connect_node(parent, newnode);
		return 0;
	}

	node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
	if (unlikely(!node)) {
		list_del(&newnode->peer_list);
			kfree(newnode);
		kmem_cache_free(node_cache, newnode);
		return -ENOMEM;
	}
	INIT_LIST_HEAD(&node->peer_list);
	copy_and_assign_cidr(node, newnode->bits, cidr, bits);

		rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
		rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
	choose_and_connect_node(node, down);
	choose_and_connect_node(node, newnode);
	if (!parent)
			rcu_assign_pointer(*trie, node);
		connect_node(trie, 2, node);
	else
			rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
					   node);
	}
		choose_and_connect_node(parent, node);
	return 0;
}

@@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
void wg_allowedips_remove_by_peer(struct allowedips *table,
				  struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
	bool free_parent;

	if (list_empty(&peer->allowedips_list))
		return;
	++table->seq;
	walk_remove_by_peer(&table->root4, peer, lock);
	walk_remove_by_peer(&table->root6, peer, lock);
	list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
		list_del_init(&node->peer_list);
		RCU_INIT_POINTER(node->peer, NULL);
		if (node->bit[0] && node->bit[1])
			continue;
		child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
						  lockdep_is_held(lock));
		if (child)
			child->parent_bit_packed = node->parent_bit_packed;
		parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
		*parent_bit = child;
		parent = (void *)parent_bit -
			 offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
		free_parent = !rcu_access_pointer(node->bit[0]) &&
			      !rcu_access_pointer(node->bit[1]) &&
			      (node->parent_bit_packed & 3) <= 1 &&
			      !rcu_access_pointer(parent->peer);
		if (free_parent)
			child = rcu_dereference_protected(
					parent->bit[!(node->parent_bit_packed & 1)],
					lockdep_is_held(lock));
		call_rcu(&node->rcu, node_free_rcu);
		if (!free_parent)
			continue;
		if (child)
			child->parent_bit_packed = parent->parent_bit_packed;
		*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
		call_rcu(&parent->rcu, node_free_rcu);
	}
}

int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
@@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
	return NULL;
}

int __init wg_allowedips_slab_init(void)
{
	node_cache = KMEM_CACHE(allowedips_node, 0);
	return node_cache ? 0 : -ENOMEM;
}

void wg_allowedips_slab_uninit(void)
{
	rcu_barrier();
	kmem_cache_destroy(node_cache);
}

#include "selftest/allowedips.c"
+7 −7
Original line number Diff line number Diff line
@@ -15,14 +15,11 @@ struct wg_peer;
struct allowedips_node {
	struct wg_peer __rcu *peer;
	struct allowedips_node __rcu *bit[2];
	/* While it may seem scandalous that we waste space for v4,
	 * we're alloc'ing to the nearest power of 2 anyway, so this
	 * doesn't actually make a difference.
	 */
	u8 bits[16] __aligned(__alignof(u64));
	u8 cidr, bit_at_a, bit_at_b, bitlen;
	u8 bits[16] __aligned(__alignof(u64));

	/* Keep rarely used list at bottom to be beyond cache line. */
	/* Keep rarely used members at bottom to be beyond cache line. */
	unsigned long parent_bit_packed;
	union {
		struct list_head peer_list;
		struct rcu_head rcu;
@@ -33,7 +30,7 @@ struct allowedips {
	struct allowedips_node __rcu *root4;
	struct allowedips_node __rcu *root6;
	u64 seq;
};
} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */

void wg_allowedips_init(struct allowedips *table);
void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
@@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
bool wg_allowedips_selftest(void);
#endif

int wg_allowedips_slab_init(void);
void wg_allowedips_slab_uninit(void);

#endif /* _WG_ALLOWEDIPS_H */
+16 −1
Original line number Diff line number Diff line
@@ -21,13 +21,22 @@ static int __init mod_init(void)
{
	int ret;

	ret = wg_allowedips_slab_init();
	if (ret < 0)
		goto err_allowedips;

#ifdef DEBUG
	ret = -ENOTRECOVERABLE;
	if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() ||
	    !wg_ratelimiter_selftest())
		return -ENOTRECOVERABLE;
		goto err_peer;
#endif
	wg_noise_init();

	ret = wg_peer_init();
	if (ret < 0)
		goto err_peer;

	ret = wg_device_init();
	if (ret < 0)
		goto err_device;
@@ -44,6 +53,10 @@ static int __init mod_init(void)
err_netlink:
	wg_device_uninit();
err_device:
	wg_peer_uninit();
err_peer:
	wg_allowedips_slab_uninit();
err_allowedips:
	return ret;
}

@@ -51,6 +64,8 @@ static void __exit mod_exit(void)
{
	wg_genetlink_uninit();
	wg_device_uninit();
	wg_peer_uninit();
	wg_allowedips_slab_uninit();
}

module_init(mod_init);
+20 −7
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
#include <linux/rcupdate.h>
#include <linux/list.h>

static struct kmem_cache *peer_cache;
static atomic64_t peer_counter = ATOMIC64_INIT(0);

struct wg_peer *wg_peer_create(struct wg_device *wg,
@@ -29,10 +30,10 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
	if (wg->num_peers >= MAX_PEERS_PER_DEVICE)
		return ERR_PTR(ret);

	peer = kzalloc(sizeof(*peer), GFP_KERNEL);
	peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL);
	if (unlikely(!peer))
		return ERR_PTR(ret);
	if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
	if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)))
		goto err;

	peer->device = wg;
@@ -64,7 +65,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
	return peer;

err:
	kfree(peer);
	kmem_cache_free(peer_cache, peer);
	return ERR_PTR(ret);
}

@@ -88,7 +89,7 @@ static void peer_make_dead(struct wg_peer *peer)
	/* Mark as dead, so that we don't allow jumping contexts after. */
	WRITE_ONCE(peer->is_dead, true);

	/* The caller must now synchronize_rcu() for this to take effect. */
	/* The caller must now synchronize_net() for this to take effect. */
}

static void peer_remove_after_dead(struct wg_peer *peer)
@@ -160,7 +161,7 @@ void wg_peer_remove(struct wg_peer *peer)
	lockdep_assert_held(&peer->device->device_update_lock);

	peer_make_dead(peer);
	synchronize_rcu();
	synchronize_net();
	peer_remove_after_dead(peer);
}

@@ -178,7 +179,7 @@ void wg_peer_remove_all(struct wg_device *wg)
		peer_make_dead(peer);
		list_add_tail(&peer->peer_list, &dead_peers);
	}
	synchronize_rcu();
	synchronize_net();
	list_for_each_entry_safe(peer, temp, &dead_peers, peer_list)
		peer_remove_after_dead(peer);
}
@@ -193,7 +194,8 @@ static void rcu_release(struct rcu_head *rcu)
	/* The final zeroing takes care of clearing any remaining handshake key
	 * material and other potentially sensitive information.
	 */
	kfree_sensitive(peer);
	memzero_explicit(peer, sizeof(*peer));
	kmem_cache_free(peer_cache, peer);
}

static void kref_release(struct kref *refcount)
@@ -225,3 +227,14 @@ void wg_peer_put(struct wg_peer *peer)
		return;
	kref_put(&peer->refcount, kref_release);
}

int __init wg_peer_init(void)
{
	peer_cache = KMEM_CACHE(wg_peer, 0);
	return peer_cache ? 0 : -ENOMEM;
}

void wg_peer_uninit(void)
{
	kmem_cache_destroy(peer_cache);
}
Loading