Commit 4101971a authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'l2tp-races'

Cong Wang says:

====================
l2tp: fix race conditions in l2tp_tunnel_register()

This patchset contains two patches, the first one is a preparation for
the second one which is the actual fix. Please find more details in
each patch description.

I have run the l2tp test suite (https://github.com/katalix/l2tp-ktest);
all test cases passed.

v3: preserve EEXIST errno for user-space
v2: move IDR allocation to l2tp_tunnel_register()
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
parents 3a415d59 0b2c5972
Loading
Loading
Loading
Loading
+52 −53
Original line number Diff line number Diff line
@@ -104,9 +104,9 @@ static struct workqueue_struct *l2tp_wq;
/* per-net private data for this module */
static unsigned int l2tp_net_id;
struct l2tp_net {
	struct list_head l2tp_tunnel_list;
	/* Lock for write access to l2tp_tunnel_list */
	spinlock_t l2tp_tunnel_list_lock;
	/* Lock for write access to l2tp_tunnel_idr */
	spinlock_t l2tp_tunnel_idr_lock;
	struct idr l2tp_tunnel_idr;
	struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2];
	/* Lock for write access to l2tp_session_hlist */
	spinlock_t l2tp_session_hlist_lock;
@@ -208,14 +208,11 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
	struct l2tp_tunnel *tunnel;

	rcu_read_lock_bh();
	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
		if (tunnel->tunnel_id == tunnel_id &&
		    refcount_inc_not_zero(&tunnel->ref_count)) {
	tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id);
	if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) {
		rcu_read_unlock_bh();

		return tunnel;
	}
	}
	rcu_read_unlock_bh();

	return NULL;
@@ -224,13 +221,14 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_get);

struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
{
	const struct l2tp_net *pn = l2tp_pernet(net);
	struct l2tp_net *pn = l2tp_pernet(net);
	unsigned long tunnel_id, tmp;
	struct l2tp_tunnel *tunnel;
	int count = 0;

	rcu_read_lock_bh();
	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
		if (++count > nth &&
	idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
		if (tunnel && ++count > nth &&
		    refcount_inc_not_zero(&tunnel->ref_count)) {
			rcu_read_unlock_bh();
			return tunnel;
@@ -1043,7 +1041,7 @@ static int l2tp_xmit_core(struct l2tp_session *session, struct sk_buff *skb, uns
	IPCB(skb)->flags &= ~(IPSKB_XFRM_TUNNEL_SIZE | IPSKB_XFRM_TRANSFORMED | IPSKB_REROUTED);
	nf_reset_ct(skb);

	bh_lock_sock(sk);
	bh_lock_sock_nested(sk);
	if (sock_owned_by_user(sk)) {
		kfree_skb(skb);
		ret = NET_XMIT_DROP;
@@ -1227,6 +1225,15 @@ static void l2tp_udp_encap_destroy(struct sock *sk)
		l2tp_tunnel_delete(tunnel);
}

/* Remove @tunnel from the per-net tunnel IDR, keyed by tunnel->tunnel_id.
 * Takes l2tp_tunnel_idr_lock (BH-disabled spinlock) because writes to
 * l2tp_tunnel_idr must be serialized under that lock.
 */
static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel)
{
	struct l2tp_net *pn = l2tp_pernet(net);

	spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
	idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id);
	spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
}

/* Workqueue tunnel deletion function */
static void l2tp_tunnel_del_work(struct work_struct *work)
{
@@ -1234,7 +1241,6 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
						  del_work);
	struct sock *sk = tunnel->sock;
	struct socket *sock = sk->sk_socket;
	struct l2tp_net *pn;

	l2tp_tunnel_closeall(tunnel);

@@ -1248,12 +1254,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
		}
	}

	/* Remove the tunnel struct from the tunnel list */
	pn = l2tp_pernet(tunnel->l2tp_net);
	spin_lock_bh(&pn->l2tp_tunnel_list_lock);
	list_del_rcu(&tunnel->list);
	spin_unlock_bh(&pn->l2tp_tunnel_list_lock);

	l2tp_tunnel_remove(tunnel->l2tp_net, tunnel);
	/* drop initial ref */
	l2tp_tunnel_dec_refcount(tunnel);

@@ -1384,8 +1385,6 @@ static int l2tp_tunnel_sock_create(struct net *net,
	return err;
}

static struct lock_class_key l2tp_socket_class;

int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id,
		       struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp)
{
@@ -1455,12 +1454,19 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net,
int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
			 struct l2tp_tunnel_cfg *cfg)
{
	struct l2tp_tunnel *tunnel_walk;
	struct l2tp_net *pn;
	struct l2tp_net *pn = l2tp_pernet(net);
	u32 tunnel_id = tunnel->tunnel_id;
	struct socket *sock;
	struct sock *sk;
	int ret;

	spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
	ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id,
			    GFP_ATOMIC);
	spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
	if (ret)
		return ret == -ENOSPC ? -EEXIST : ret;

	if (tunnel->fd < 0) {
		ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id,
					      tunnel->peer_tunnel_id, cfg,
@@ -1474,31 +1480,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
	}

	sk = sock->sk;
	lock_sock(sk);
	write_lock_bh(&sk->sk_callback_lock);
	ret = l2tp_validate_socket(sk, net, tunnel->encap);
	if (ret < 0)
	if (ret < 0) {
		release_sock(sk);
		goto err_inval_sock;
	}
	rcu_assign_sk_user_data(sk, tunnel);
	write_unlock_bh(&sk->sk_callback_lock);

	tunnel->l2tp_net = net;
	pn = l2tp_pernet(net);

	sock_hold(sk);
	tunnel->sock = sk;

	spin_lock_bh(&pn->l2tp_tunnel_list_lock);
	list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) {
		if (tunnel_walk->tunnel_id == tunnel->tunnel_id) {
			spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
			sock_put(sk);
			ret = -EEXIST;
			goto err_sock;
		}
	}
	list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
	spin_unlock_bh(&pn->l2tp_tunnel_list_lock);

	if (tunnel->encap == L2TP_ENCAPTYPE_UDP) {
		struct udp_tunnel_sock_cfg udp_cfg = {
			.sk_user_data = tunnel,
@@ -1512,9 +1503,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,

	tunnel->old_sk_destruct = sk->sk_destruct;
	sk->sk_destruct = &l2tp_tunnel_destruct;
	lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class,
				   "l2tp_sock");
	sk->sk_allocation = GFP_ATOMIC;
	release_sock(sk);

	sock_hold(sk);
	tunnel->sock = sk;
	tunnel->l2tp_net = net;

	spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
	idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id);
	spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);

	trace_register_tunnel(tunnel);

@@ -1523,9 +1521,6 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,

	return 0;

err_sock:
	write_lock_bh(&sk->sk_callback_lock);
	rcu_assign_sk_user_data(sk, NULL);
err_inval_sock:
	write_unlock_bh(&sk->sk_callback_lock);

@@ -1534,6 +1529,7 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
	else
		sockfd_put(sock);
err:
	l2tp_tunnel_remove(net, tunnel);
	return ret;
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_register);
@@ -1647,8 +1643,8 @@ static __net_init int l2tp_init_net(struct net *net)
	struct l2tp_net *pn = net_generic(net, l2tp_net_id);
	int hash;

	INIT_LIST_HEAD(&pn->l2tp_tunnel_list);
	spin_lock_init(&pn->l2tp_tunnel_list_lock);
	idr_init(&pn->l2tp_tunnel_idr);
	spin_lock_init(&pn->l2tp_tunnel_idr_lock);

	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
		INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]);
@@ -1662,10 +1658,12 @@ static __net_exit void l2tp_exit_net(struct net *net)
{
	struct l2tp_net *pn = l2tp_pernet(net);
	struct l2tp_tunnel *tunnel = NULL;
	unsigned long tunnel_id, tmp;
	int hash;

	rcu_read_lock_bh();
	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
	idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
		if (tunnel)
			l2tp_tunnel_delete(tunnel);
	}
	rcu_read_unlock_bh();
@@ -1676,6 +1674,7 @@ static __net_exit void l2tp_exit_net(struct net *net)

	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
		WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
	idr_destroy(&pn->l2tp_tunnel_idr);
}

static struct pernet_operations l2tp_net_ops = {