Commit 6dd4142f authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'af_unix-per-netns-socket-hash'

Kuniyuki Iwashima says:

====================
af_unix: Introduce per-netns socket hash table.

This series replaces unix_socket_table with a per-netns hash table and
reduces lock contention and time on iterating over the list.

Note the 3rd-6th patches can be a single patch, but for ease of review,
they are split into small changes without breakage.

Changes:
  v3:
    6th:
      * Remove unix_table_locks from comments.
      * Remove missed spin_unlock(&unix_table_locks) in
        unix_lookup_by_ino() (kernel test robot)

  v2: https://lore.kernel.org/netdev/20220620185151.65294-1-kuniyu@amazon.com/
    3rd:
      * Update changelog
      * Remove holes from per-netns hash table structure
      * Use kvmalloc_array() instead of kmalloc() (Eric Dumazet)
      * Remove unnecessary parts in af_unix_init() (Eric Dumazet)
      * Move `err_sysctl` label into ifdef block (kernel test robot)
      * Remove struct netns_unix from struct net if CONFIG_UNIX is disabled
    4th:
      * Use spin_lock_nested() (kernel test robot)

  v1: https://lore.kernel.org/netdev/20220616234714.4291-1-kuniyu@amazon.com/


====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents ffd3018b 2f7ca90a
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -16,12 +16,11 @@ void wait_for_unix_gc(void);
struct sock *unix_get_socket(struct file *filp);
struct sock *unix_peer_get(struct sock *sk);

#define UNIX_HASH_SIZE	256
#define UNIX_HASH_MOD	(256 - 1)
#define UNIX_HASH_SIZE	(256 * 2)
#define UNIX_HASH_BITS	8

extern unsigned int unix_tot_inflight;
extern spinlock_t unix_table_locks[2 * UNIX_HASH_SIZE];
extern struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE];

struct unix_address {
	refcount_t	refcnt;
+2 −0
Original line number Diff line number Diff line
@@ -120,7 +120,9 @@ struct net {
	struct netns_core	core;
	struct netns_mib	mib;
	struct netns_packet	packet;
#if IS_ENABLED(CONFIG_UNIX)
	struct netns_unix	unx;
#endif
	struct netns_nexthop	nexthop;
	struct netns_ipv4	ipv4;
#if IS_ENABLED(CONFIG_IPV6)
+6 −0
Original line number Diff line number Diff line
@@ -5,8 +5,14 @@
#ifndef __NETNS_UNIX_H__
#define __NETNS_UNIX_H__

struct unix_table {
	spinlock_t		*locks;
	struct hlist_head	*buckets;
};

struct ctl_table_header;
struct netns_unix {
	struct unix_table	table;
	int			sysctl_max_dgram_qlen;
	struct ctl_table_header	*ctl;
};
+120 −108
Original line number Diff line number Diff line
@@ -118,14 +118,10 @@

#include "scm.h"

spinlock_t unix_table_locks[2 * UNIX_HASH_SIZE];
EXPORT_SYMBOL_GPL(unix_table_locks);
struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE];
EXPORT_SYMBOL_GPL(unix_socket_table);
static atomic_long_t unix_nr_socks;

/* SMP locking strategy:
 *    hash table is protected with spinlock unix_table_locks
 *    hash table is protected with spinlock.
 *    each socket state is protected by separate spinlock.
 */

@@ -137,12 +133,12 @@ static unsigned int unix_unbound_hash(struct sock *sk)
	hash ^= hash >> 8;
	hash ^= sk->sk_type;

	return UNIX_HASH_SIZE + (hash & (UNIX_HASH_SIZE - 1));
	return UNIX_HASH_MOD + 1 + (hash & UNIX_HASH_MOD);
}

static unsigned int unix_bsd_hash(struct inode *i)
{
	return i->i_ino & (UNIX_HASH_SIZE - 1);
	return i->i_ino & UNIX_HASH_MOD;
}

static unsigned int unix_abstract_hash(struct sockaddr_un *sunaddr,
@@ -155,26 +151,28 @@ static unsigned int unix_abstract_hash(struct sockaddr_un *sunaddr,
	hash ^= hash >> 8;
	hash ^= type;

	return hash & (UNIX_HASH_SIZE - 1);
	return hash & UNIX_HASH_MOD;
}

static void unix_table_double_lock(unsigned int hash1, unsigned int hash2)
static void unix_table_double_lock(struct net *net,
				   unsigned int hash1, unsigned int hash2)
{
	/* hash1 and hash2 is never the same because
	 * one is between 0 and UNIX_HASH_SIZE - 1, and
	 * another is between UNIX_HASH_SIZE and UNIX_HASH_SIZE * 2.
	 * one is between 0 and UNIX_HASH_MOD, and
	 * another is between UNIX_HASH_MOD + 1 and UNIX_HASH_SIZE - 1.
	 */
	if (hash1 > hash2)
		swap(hash1, hash2);

	spin_lock(&unix_table_locks[hash1]);
	spin_lock_nested(&unix_table_locks[hash2], SINGLE_DEPTH_NESTING);
	spin_lock(&net->unx.table.locks[hash1]);
	spin_lock_nested(&net->unx.table.locks[hash2], SINGLE_DEPTH_NESTING);
}

static void unix_table_double_unlock(unsigned int hash1, unsigned int hash2)
static void unix_table_double_unlock(struct net *net,
				     unsigned int hash1, unsigned int hash2)
{
	spin_unlock(&unix_table_locks[hash1]);
	spin_unlock(&unix_table_locks[hash2]);
	spin_unlock(&net->unx.table.locks[hash1]);
	spin_unlock(&net->unx.table.locks[hash2]);
}

#ifdef CONFIG_SECURITY_NETWORK
@@ -300,34 +298,34 @@ static void __unix_remove_socket(struct sock *sk)
	sk_del_node_init(sk);
}

static void __unix_insert_socket(struct sock *sk)
static void __unix_insert_socket(struct net *net, struct sock *sk)
{
	DEBUG_NET_WARN_ON_ONCE(!sk_unhashed(sk));
	sk_add_node(sk, &unix_socket_table[sk->sk_hash]);
	sk_add_node(sk, &net->unx.table.buckets[sk->sk_hash]);
}

static void __unix_set_addr_hash(struct sock *sk, struct unix_address *addr,
				 unsigned int hash)
static void __unix_set_addr_hash(struct net *net, struct sock *sk,
				 struct unix_address *addr, unsigned int hash)
{
	__unix_remove_socket(sk);
	smp_store_release(&unix_sk(sk)->addr, addr);

	sk->sk_hash = hash;
	__unix_insert_socket(sk);
	__unix_insert_socket(net, sk);
}

static void unix_remove_socket(struct sock *sk)
static void unix_remove_socket(struct net *net, struct sock *sk)
{
	spin_lock(&unix_table_locks[sk->sk_hash]);
	spin_lock(&net->unx.table.locks[sk->sk_hash]);
	__unix_remove_socket(sk);
	spin_unlock(&unix_table_locks[sk->sk_hash]);
	spin_unlock(&net->unx.table.locks[sk->sk_hash]);
}

static void unix_insert_unbound_socket(struct sock *sk)
static void unix_insert_unbound_socket(struct net *net, struct sock *sk)
{
	spin_lock(&unix_table_locks[sk->sk_hash]);
	__unix_insert_socket(sk);
	spin_unlock(&unix_table_locks[sk->sk_hash]);
	spin_lock(&net->unx.table.locks[sk->sk_hash]);
	__unix_insert_socket(net, sk);
	spin_unlock(&net->unx.table.locks[sk->sk_hash]);
}

static struct sock *__unix_find_socket_byname(struct net *net,
@@ -336,12 +334,9 @@ static struct sock *__unix_find_socket_byname(struct net *net,
{
	struct sock *s;

	sk_for_each(s, &unix_socket_table[hash]) {
	sk_for_each(s, &net->unx.table.buckets[hash]) {
		struct unix_sock *u = unix_sk(s);

		if (!net_eq(sock_net(s), net))
			continue;

		if (u->addr->len == len &&
		    !memcmp(u->addr->name, sunname, len))
			return s;
@@ -355,30 +350,30 @@ static inline struct sock *unix_find_socket_byname(struct net *net,
{
	struct sock *s;

	spin_lock(&unix_table_locks[hash]);
	spin_lock(&net->unx.table.locks[hash]);
	s = __unix_find_socket_byname(net, sunname, len, hash);
	if (s)
		sock_hold(s);
	spin_unlock(&unix_table_locks[hash]);
	spin_unlock(&net->unx.table.locks[hash]);
	return s;
}

static struct sock *unix_find_socket_byinode(struct inode *i)
static struct sock *unix_find_socket_byinode(struct net *net, struct inode *i)
{
	unsigned int hash = unix_bsd_hash(i);
	struct sock *s;

	spin_lock(&unix_table_locks[hash]);
	sk_for_each(s, &unix_socket_table[hash]) {
	spin_lock(&net->unx.table.locks[hash]);
	sk_for_each(s, &net->unx.table.buckets[hash]) {
		struct dentry *dentry = unix_sk(s)->path.dentry;

		if (dentry && d_backing_inode(dentry) == i) {
			sock_hold(s);
			spin_unlock(&unix_table_locks[hash]);
			spin_unlock(&net->unx.table.locks[hash]);
			return s;
		}
	}
	spin_unlock(&unix_table_locks[hash]);
	spin_unlock(&net->unx.table.locks[hash]);
	return NULL;
}

@@ -576,12 +571,12 @@ static void unix_sock_destructor(struct sock *sk)
static void unix_release_sock(struct sock *sk, int embrion)
{
	struct unix_sock *u = unix_sk(sk);
	struct path path;
	struct sock *skpair;
	struct sk_buff *skb;
	struct path path;
	int state;

	unix_remove_socket(sk);
	unix_remove_socket(sock_net(sk), sk);

	/* Clear state */
	unix_state_lock(sk);
@@ -930,9 +925,9 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern,
	init_waitqueue_head(&u->peer_wait);
	init_waitqueue_func_entry(&u->peer_wake, unix_dgram_peer_wake_relay);
	memset(&u->scm_stat, 0, sizeof(struct scm_stat));
	unix_insert_unbound_socket(sk);
	unix_insert_unbound_socket(net, sk);

	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
	sock_prot_inuse_add(net, sk->sk_prot, 1);

	return sk;

@@ -1015,7 +1010,7 @@ static struct sock *unix_find_bsd(struct net *net, struct sockaddr_un *sunaddr,
	if (!S_ISSOCK(inode->i_mode))
		goto path_put;

	sk = unix_find_socket_byinode(inode);
	sk = unix_find_socket_byinode(net, inode);
	if (!sk)
		goto path_put;

@@ -1074,6 +1069,7 @@ static int unix_autobind(struct sock *sk)
{
	unsigned int new_hash, old_hash = sk->sk_hash;
	struct unix_sock *u = unix_sk(sk);
	struct net *net = sock_net(sk);
	struct unix_address *addr;
	u32 lastnum, ordernum;
	int err;
@@ -1102,11 +1098,10 @@ static int unix_autobind(struct sock *sk)
	sprintf(addr->name->sun_path + 1, "%05x", ordernum);

	new_hash = unix_abstract_hash(addr->name, addr->len, sk->sk_type);
	unix_table_double_lock(old_hash, new_hash);
	unix_table_double_lock(net, old_hash, new_hash);

	if (__unix_find_socket_byname(sock_net(sk), addr->name, addr->len,
				      new_hash)) {
		unix_table_double_unlock(old_hash, new_hash);
	if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash)) {
		unix_table_double_unlock(net, old_hash, new_hash);

		/* __unix_find_socket_byname() may take long time if many names
		 * are already in use.
@@ -1123,8 +1118,8 @@ static int unix_autobind(struct sock *sk)
		goto retry;
	}

	__unix_set_addr_hash(sk, addr, new_hash);
	unix_table_double_unlock(old_hash, new_hash);
	__unix_set_addr_hash(net, sk, addr, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	err = 0;

out:	mutex_unlock(&u->bindlock);
@@ -1138,6 +1133,7 @@ static int unix_bind_bsd(struct sock *sk, struct sockaddr_un *sunaddr,
	       (SOCK_INODE(sk->sk_socket)->i_mode & ~current_umask());
	unsigned int new_hash, old_hash = sk->sk_hash;
	struct unix_sock *u = unix_sk(sk);
	struct net *net = sock_net(sk);
	struct user_namespace *ns; // barf...
	struct unix_address *addr;
	struct dentry *dentry;
@@ -1178,11 +1174,11 @@ static int unix_bind_bsd(struct sock *sk, struct sockaddr_un *sunaddr,
		goto out_unlock;

	new_hash = unix_bsd_hash(d_backing_inode(dentry));
	unix_table_double_lock(old_hash, new_hash);
	unix_table_double_lock(net, old_hash, new_hash);
	u->path.mnt = mntget(parent.mnt);
	u->path.dentry = dget(dentry);
	__unix_set_addr_hash(sk, addr, new_hash);
	unix_table_double_unlock(old_hash, new_hash);
	__unix_set_addr_hash(net, sk, addr, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	mutex_unlock(&u->bindlock);
	done_path_create(&parent, dentry);
	return 0;
@@ -1205,6 +1201,7 @@ static int unix_bind_abstract(struct sock *sk, struct sockaddr_un *sunaddr,
{
	unsigned int new_hash, old_hash = sk->sk_hash;
	struct unix_sock *u = unix_sk(sk);
	struct net *net = sock_net(sk);
	struct unix_address *addr;
	int err;

@@ -1222,19 +1219,18 @@ static int unix_bind_abstract(struct sock *sk, struct sockaddr_un *sunaddr,
	}

	new_hash = unix_abstract_hash(addr->name, addr->len, sk->sk_type);
	unix_table_double_lock(old_hash, new_hash);
	unix_table_double_lock(net, old_hash, new_hash);

	if (__unix_find_socket_byname(sock_net(sk), addr->name, addr->len,
				      new_hash))
	if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash))
		goto out_spin;

	__unix_set_addr_hash(sk, addr, new_hash);
	unix_table_double_unlock(old_hash, new_hash);
	__unix_set_addr_hash(net, sk, addr, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	mutex_unlock(&u->bindlock);
	return 0;

out_spin:
	unix_table_double_unlock(old_hash, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	err = -EADDRINUSE;
out_mutex:
	mutex_unlock(&u->bindlock);
@@ -1293,9 +1289,8 @@ static void unix_state_double_unlock(struct sock *sk1, struct sock *sk2)
static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
			      int alen, int flags)
{
	struct sock *sk = sock->sk;
	struct net *net = sock_net(sk);
	struct sockaddr_un *sunaddr = (struct sockaddr_un *)addr;
	struct sock *sk = sock->sk;
	struct sock *other;
	int err;

@@ -1316,7 +1311,7 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
		}

restart:
		other = unix_find_other(net, sunaddr, alen, sock->type);
		other = unix_find_other(sock_net(sk), sunaddr, alen, sock->type);
		if (IS_ERR(other)) {
			err = PTR_ERR(other);
			goto out;
@@ -1404,15 +1399,13 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
			       int addr_len, int flags)
{
	struct sockaddr_un *sunaddr = (struct sockaddr_un *)uaddr;
	struct sock *sk = sock->sk;
	struct net *net = sock_net(sk);
	struct sock *sk = sock->sk, *newsk = NULL, *other = NULL;
	struct unix_sock *u = unix_sk(sk), *newu, *otheru;
	struct sock *newsk = NULL;
	struct sock *other = NULL;
	struct net *net = sock_net(sk);
	struct sk_buff *skb = NULL;
	int st;
	int err;
	long timeo;
	int err;
	int st;

	err = unix_validate_addr(sunaddr, addr_len);
	if (err)
@@ -1432,7 +1425,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
	 */

	/* create new sock for complete connection */
	newsk = unix_create1(sock_net(sk), NULL, 0, sock->type);
	newsk = unix_create1(net, NULL, 0, sock->type);
	if (IS_ERR(newsk)) {
		err = PTR_ERR(newsk);
		newsk = NULL;
@@ -1541,9 +1534,9 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
	 *
	 * The contents of *(otheru->addr) and otheru->path
	 * are seen fully set up here, since we have found
	 * otheru in hash under unix_table_locks.  Insertion
	 * into the hash chain we'd found it in had been done
	 * in an earlier critical area protected by unix_table_locks,
	 * otheru in hash under its lock.  Insertion into the
	 * hash chain we'd found it in had been done in an
	 * earlier critical area protected by the chain's lock,
	 * the same one where we'd set *(otheru->addr) contents,
	 * as well as otheru->path and otheru->addr itself.
	 *
@@ -1840,17 +1833,15 @@ static void scm_stat_del(struct sock *sk, struct sk_buff *skb)
static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
			      size_t len)
{
	struct sock *sk = sock->sk;
	struct net *net = sock_net(sk);
	struct unix_sock *u = unix_sk(sk);
	DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, msg->msg_name);
	struct sock *other = NULL;
	int err;
	struct sk_buff *skb;
	long timeo;
	struct sock *sk = sock->sk, *other = NULL;
	struct unix_sock *u = unix_sk(sk);
	struct scm_cookie scm;
	struct sk_buff *skb;
	int data_len = 0;
	int sk_locked;
	long timeo;
	int err;

	wait_for_unix_gc();
	err = scm_send(sock, msg, &scm, false);
@@ -1917,7 +1908,7 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
		if (sunaddr == NULL)
			goto out_free;

		other = unix_find_other(net, sunaddr, msg->msg_namelen,
		other = unix_find_other(sock_net(sk), sunaddr, msg->msg_namelen,
					sk->sk_type);
		if (IS_ERR(other)) {
			err = PTR_ERR(other);
@@ -3226,12 +3217,11 @@ static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos)
{
	unsigned long offset = get_offset(*pos);
	unsigned long bucket = get_bucket(*pos);
	struct sock *sk;
	unsigned long count = 0;
	struct sock *sk;

	for (sk = sk_head(&unix_socket_table[bucket]); sk; sk = sk_next(sk)) {
		if (sock_net(sk) != seq_file_net(seq))
			continue;
	for (sk = sk_head(&seq_file_net(seq)->unx.table.buckets[bucket]);
	     sk; sk = sk_next(sk)) {
		if (++count == offset)
			break;
	}
@@ -3242,16 +3232,17 @@ static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos)
static struct sock *unix_get_first(struct seq_file *seq, loff_t *pos)
{
	unsigned long bucket = get_bucket(*pos);
	struct net *net = seq_file_net(seq);
	struct sock *sk;

	while (bucket < ARRAY_SIZE(unix_socket_table)) {
		spin_lock(&unix_table_locks[bucket]);
	while (bucket < UNIX_HASH_SIZE) {
		spin_lock(&net->unx.table.locks[bucket]);

		sk = unix_from_bucket(seq, pos);
		if (sk)
			return sk;

		spin_unlock(&unix_table_locks[bucket]);
		spin_unlock(&net->unx.table.locks[bucket]);

		*pos = set_bucket_offset(++bucket, 1);
	}
@@ -3264,11 +3255,12 @@ static struct sock *unix_get_next(struct seq_file *seq, struct sock *sk,
{
	unsigned long bucket = get_bucket(*pos);

	for (sk = sk_next(sk); sk; sk = sk_next(sk))
		if (sock_net(sk) == seq_file_net(seq))
	sk = sk_next(sk);
	if (sk)
		return sk;

	spin_unlock(&unix_table_locks[bucket]);

	spin_unlock(&seq_file_net(seq)->unx.table.locks[bucket]);

	*pos = set_bucket_offset(++bucket, 1);

@@ -3298,7 +3290,7 @@ static void unix_seq_stop(struct seq_file *seq, void *v)
	struct sock *sk = v;

	if (sk)
		spin_unlock(&unix_table_locks[sk->sk_hash]);
		spin_unlock(&seq_file_net(seq)->unx.table.locks[sk->sk_hash]);
}

static int unix_seq_show(struct seq_file *seq, void *v)
@@ -3323,7 +3315,7 @@ static int unix_seq_show(struct seq_file *seq, void *v)
			(s->sk_state == TCP_ESTABLISHED ? SS_CONNECTING : SS_DISCONNECTING),
			sock_i_ino(s));

		if (u->addr) {	// under unix_table_locks here
		if (u->addr) {	// under a hash table lock here
			int i, len;
			seq_putc(seq, ' ');

@@ -3393,9 +3385,6 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
	iter->batch[iter->end_sk++] = start_sk;

	for (sk = sk_next(start_sk); sk; sk = sk_next(sk)) {
		if (sock_net(sk) != seq_file_net(seq))
			continue;

		if (iter->end_sk < iter->max_sk) {
			sock_hold(sk);
			iter->batch[iter->end_sk++] = sk;
@@ -3404,7 +3393,7 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
		expected++;
	}

	spin_unlock(&unix_table_locks[start_sk->sk_hash]);
	spin_unlock(&seq_file_net(seq)->unx.table.locks[start_sk->sk_hash]);

	return expected;
}
@@ -3564,7 +3553,7 @@ static const struct net_proto_family unix_family_ops = {

static int __net_init unix_net_init(struct net *net)
{
	int error = -ENOMEM;
	int i;

	net->unx.sysctl_max_dgram_qlen = 10;
	if (unix_sysctl_register(net))
@@ -3572,18 +3561,44 @@ static int __net_init unix_net_init(struct net *net)

#ifdef CONFIG_PROC_FS
	if (!proc_create_net("unix", 0, net->proc_net, &unix_seq_ops,
			sizeof(struct seq_net_private))) {
		unix_sysctl_unregister(net);
		goto out;
			     sizeof(struct seq_net_private)))
		goto err_sysctl;
#endif

	net->unx.table.locks = kvmalloc_array(UNIX_HASH_SIZE,
					      sizeof(spinlock_t), GFP_KERNEL);
	if (!net->unx.table.locks)
		goto err_proc;

	net->unx.table.buckets = kvmalloc_array(UNIX_HASH_SIZE,
						sizeof(struct hlist_head),
						GFP_KERNEL);
	if (!net->unx.table.buckets)
		goto free_locks;

	for (i = 0; i < UNIX_HASH_SIZE; i++) {
		spin_lock_init(&net->unx.table.locks[i]);
		INIT_HLIST_HEAD(&net->unx.table.buckets[i]);
	}

	return 0;

free_locks:
	kvfree(net->unx.table.locks);
err_proc:
#ifdef CONFIG_PROC_FS
	remove_proc_entry("unix", net->proc_net);
err_sysctl:
#endif
	error = 0;
	unix_sysctl_unregister(net);
out:
	return error;
	return -ENOMEM;
}

static void __net_exit unix_net_exit(struct net *net)
{
	kvfree(net->unx.table.buckets);
	kvfree(net->unx.table.locks);
	unix_sysctl_unregister(net);
	remove_proc_entry("unix", net->proc_net);
}
@@ -3667,13 +3682,10 @@ static void __init bpf_iter_register(void)

static int __init af_unix_init(void)
{
	int i, rc = -1;
	int rc = -1;

	BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb));

	for (i = 0; i < 2 * UNIX_HASH_SIZE; i++)
		spin_lock_init(&unix_table_locks[i]);

	rc = proto_register(&unix_dgram_proto, 1);
	if (rc != 0) {
		pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
+22 −27
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@

static int sk_diag_dump_name(struct sock *sk, struct sk_buff *nlskb)
{
	/* might or might not have unix_table_locks */
	/* might or might not have a hash table lock */
	struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr);

	if (!addr)
@@ -195,25 +195,21 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_r

static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct unix_diag_req *req;
	int num, s_num, slot, s_slot;
	struct net *net = sock_net(skb->sk);
	int num, s_num, slot, s_slot;
	struct unix_diag_req *req;

	req = nlmsg_data(cb->nlh);

	s_slot = cb->args[0];
	num = s_num = cb->args[1];

	for (slot = s_slot;
	     slot < ARRAY_SIZE(unix_socket_table);
	     s_num = 0, slot++) {
	for (slot = s_slot; slot < UNIX_HASH_SIZE; s_num = 0, slot++) {
		struct sock *sk;

		num = 0;
		spin_lock(&unix_table_locks[slot]);
		sk_for_each(sk, &unix_socket_table[slot]) {
			if (!net_eq(sock_net(sk), net))
				continue;
		spin_lock(&net->unx.table.locks[slot]);
		sk_for_each(sk, &net->unx.table.buckets[slot]) {
			if (num < s_num)
				goto next;
			if (!(req->udiag_states & (1 << sk->sk_state)))
@@ -222,13 +218,13 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
					 NETLINK_CB(cb->skb).portid,
					 cb->nlh->nlmsg_seq,
					 NLM_F_MULTI) < 0) {
				spin_unlock(&unix_table_locks[slot]);
				spin_unlock(&net->unx.table.locks[slot]);
				goto done;
			}
next:
			num++;
		}
		spin_unlock(&unix_table_locks[slot]);
		spin_unlock(&net->unx.table.locks[slot]);
	}
done:
	cb->args[0] = slot;
@@ -237,20 +233,21 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
	return skb->len;
}

static struct sock *unix_lookup_by_ino(unsigned int ino)
static struct sock *unix_lookup_by_ino(struct net *net, unsigned int ino)
{
	struct sock *sk;
	int i;

	for (i = 0; i < ARRAY_SIZE(unix_socket_table); i++) {
		spin_lock(&unix_table_locks[i]);
		sk_for_each(sk, &unix_socket_table[i])
	for (i = 0; i < UNIX_HASH_SIZE; i++) {
		spin_lock(&net->unx.table.locks[i]);
		sk_for_each(sk, &net->unx.table.buckets[i]) {
			if (ino == sock_i_ino(sk)) {
				sock_hold(sk);
				spin_unlock(&unix_table_locks[i]);
				spin_unlock(&net->unx.table.locks[i]);
				return sk;
			}
		spin_unlock(&unix_table_locks[i]);
		}
		spin_unlock(&net->unx.table.locks[i]);
	}
	return NULL;
}
@@ -259,21 +256,20 @@ static int unix_diag_get_exact(struct sk_buff *in_skb,
			       const struct nlmsghdr *nlh,
			       struct unix_diag_req *req)
{
	int err = -EINVAL;
	struct sock *sk;
	struct sk_buff *rep;
	unsigned int extra_len;
	struct net *net = sock_net(in_skb->sk);
	unsigned int extra_len;
	struct sk_buff *rep;
	struct sock *sk;
	int err;

	err = -EINVAL;
	if (req->udiag_ino == 0)
		goto out_nosk;

	sk = unix_lookup_by_ino(req->udiag_ino);
	sk = unix_lookup_by_ino(net, req->udiag_ino);
	err = -ENOENT;
	if (sk == NULL)
		goto out_nosk;
	if (!net_eq(sock_net(sk), net))
		goto out;

	err = sock_diag_check_cookie(sk, req->udiag_cookie);
	if (err)
@@ -308,7 +304,6 @@ static int unix_diag_get_exact(struct sk_buff *in_skb,
static int unix_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
{
	int hdrlen = sizeof(struct unix_diag_req);
	struct net *net = sock_net(skb->sk);

	if (nlmsg_len(h) < hdrlen)
		return -EINVAL;
@@ -317,7 +312,7 @@ static int unix_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
		struct netlink_dump_control c = {
			.dump = unix_diag_dump,
		};
		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
		return netlink_dump_start(sock_net(skb->sk)->diag_nlsk, skb, h, &c);
	} else
		return unix_diag_get_exact(skb, h, nlmsg_data(h));
}