Commit 712d4793 authored by Alexei Starovoitov

Merge branch 'bpf: Batching iter for AF_UNIX sockets.'



Kuniyuki Iwashima says:

====================

Last year the commit afd20b92 ("af_unix: Replace the big lock with
small locks.") landed on bpf-next.  Now we can use a batching algorithm
for the AF_UNIX bpf iter, as the TCP bpf iter does.

Changelog:
- Add the 1st patch.
- Call unix_get_first() in .start()/.next() to always acquire a lock in
  each iteration in the 2nd patch.
====================

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parents 2a1aff60 a796966b
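
For orientation, the sketch below shows how an iter/unix program is typically attached and read from user space with libbpf once this series is in place. It is not part of the series: the object file name "iter_unix.bpf.o" and program name "dump_unix" are placeholders, and error handling is reduced to early returns (assuming libbpf 1.0 semantics where open/attach return NULL on failure).

/* Minimal user-space sketch (not part of this series): open a BPF object
 * containing a SEC("iter/unix") program, attach it, and read the iterator.
 * "iter_unix.bpf.o" and "dump_unix" are placeholder names.
 */
#include <stdio.h>
#include <unistd.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>

int main(void)
{
	struct bpf_object *obj;
	struct bpf_program *prog;
	struct bpf_link *link;
	char buf[4096];
	ssize_t n;
	int iter_fd;

	obj = bpf_object__open_file("iter_unix.bpf.o", NULL);
	if (!obj || bpf_object__load(obj))
		return 1;

	prog = bpf_object__find_program_by_name(obj, "dump_unix");
	if (!prog)
		return 1;

	link = bpf_program__attach_iter(prog, NULL);
	if (!link)
		return 1;

	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (iter_fd < 0)
		return 1;

	/* Each read() drives .start()/.next(); with this series the kernel
	 * batches one hash bucket of AF_UNIX sockets at a time.
	 */
	while ((n = read(iter_fd, buf, sizeof(buf))) > 0)
		fwrite(buf, 1, n, stdout);

	close(iter_fd);
	bpf_link__destroy(link);
	bpf_object__close(obj);
	return 0;
}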
net/unix/af_unix.c +222 −28
@@ -3240,49 +3240,58 @@ static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos)
	return sk;
}

static struct sock *unix_next_socket(struct seq_file *seq,
				     struct sock *sk,
				     loff_t *pos)
static struct sock *unix_get_first(struct seq_file *seq, loff_t *pos)
{
	unsigned long bucket = get_bucket(*pos);
	struct sock *sk;

	while (sk > (struct sock *)SEQ_START_TOKEN) {
		sk = sk_next(sk);
		if (!sk)
			goto next_bucket;
		if (sock_net(sk) == seq_file_net(seq))
			return sk;
	}

	do {
	while (bucket < ARRAY_SIZE(unix_socket_table)) {
		spin_lock(&unix_table_locks[bucket]);

		sk = unix_from_bucket(seq, pos);
		if (sk)
			return sk;

next_bucket:
		spin_unlock(&unix_table_locks[bucket++]);
		*pos = set_bucket_offset(bucket, 1);
	} while (bucket < ARRAY_SIZE(unix_socket_table));
		spin_unlock(&unix_table_locks[bucket]);

		*pos = set_bucket_offset(++bucket, 1);
	}

	return NULL;
}

static struct sock *unix_get_next(struct seq_file *seq, struct sock *sk,
				  loff_t *pos)
{
	unsigned long bucket = get_bucket(*pos);

	for (sk = sk_next(sk); sk; sk = sk_next(sk))
		if (sock_net(sk) == seq_file_net(seq))
			return sk;

	spin_unlock(&unix_table_locks[bucket]);

	*pos = set_bucket_offset(++bucket, 1);

	return unix_get_first(seq, pos);
}

static void *unix_seq_start(struct seq_file *seq, loff_t *pos)
{
	if (!*pos)
		return SEQ_START_TOKEN;

	if (get_bucket(*pos) >= ARRAY_SIZE(unix_socket_table))
		return NULL;

	return unix_next_socket(seq, NULL, pos);
	return unix_get_first(seq, pos);
}

static void *unix_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
	++*pos;
	return unix_next_socket(seq, v, pos);

	if (v == SEQ_START_TOKEN)
		return unix_get_first(seq, pos);

	return unix_get_next(seq, v, pos);
}

static void unix_seq_stop(struct seq_file *seq, void *v)
@@ -3347,6 +3356,15 @@ static const struct seq_operations unix_seq_ops = {
};

#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL)
struct bpf_unix_iter_state {
	struct seq_net_private p;
	unsigned int cur_sk;
	unsigned int end_sk;
	unsigned int max_sk;
	struct sock **batch;
	bool st_bucket_done;
};

struct bpf_iter__unix {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct unix_sock *, unix_sk);
@@ -3365,24 +3383,156 @@ static int unix_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
	return bpf_iter_run_prog(prog, &ctx);
}

static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
{
	struct bpf_unix_iter_state *iter = seq->private;
	unsigned int expected = 1;
	struct sock *sk;

	sock_hold(start_sk);
	iter->batch[iter->end_sk++] = start_sk;

	for (sk = sk_next(start_sk); sk; sk = sk_next(sk)) {
		if (sock_net(sk) != seq_file_net(seq))
			continue;

		if (iter->end_sk < iter->max_sk) {
			sock_hold(sk);
			iter->batch[iter->end_sk++] = sk;
		}

		expected++;
	}

	spin_unlock(&unix_table_locks[start_sk->sk_hash]);

	return expected;
}

static void bpf_iter_unix_put_batch(struct bpf_unix_iter_state *iter)
{
	while (iter->cur_sk < iter->end_sk)
		sock_put(iter->batch[iter->cur_sk++]);
}

static int bpf_iter_unix_realloc_batch(struct bpf_unix_iter_state *iter,
				       unsigned int new_batch_sz)
{
	struct sock **new_batch;

	new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz,
			     GFP_USER | __GFP_NOWARN);
	if (!new_batch)
		return -ENOMEM;

	bpf_iter_unix_put_batch(iter);
	kvfree(iter->batch);
	iter->batch = new_batch;
	iter->max_sk = new_batch_sz;

	return 0;
}

static struct sock *bpf_iter_unix_batch(struct seq_file *seq,
					loff_t *pos)
{
	struct bpf_unix_iter_state *iter = seq->private;
	unsigned int expected;
	bool resized = false;
	struct sock *sk;

	if (iter->st_bucket_done)
		*pos = set_bucket_offset(get_bucket(*pos) + 1, 1);

again:
	/* Get a new batch */
	iter->cur_sk = 0;
	iter->end_sk = 0;

	sk = unix_get_first(seq, pos);
	if (!sk)
		return NULL; /* Done */

	expected = bpf_iter_unix_hold_batch(seq, sk);

	if (iter->end_sk == expected) {
		iter->st_bucket_done = true;
		return sk;
	}

	if (!resized && !bpf_iter_unix_realloc_batch(iter, expected * 3 / 2)) {
		resized = true;
		goto again;
	}

	return sk;
}

static void *bpf_iter_unix_seq_start(struct seq_file *seq, loff_t *pos)
{
	if (!*pos)
		return SEQ_START_TOKEN;

	/* bpf iter does not support lseek, so it always
	 * continues from where it was stop()-ped.
	 */
	return bpf_iter_unix_batch(seq, pos);
}

static void *bpf_iter_unix_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
	struct bpf_unix_iter_state *iter = seq->private;
	struct sock *sk;

	/* Whenever seq_next() is called, the iter->cur_sk is
	 * done with seq_show(), so advance to the next sk in
	 * the batch.
	 */
	if (iter->cur_sk < iter->end_sk)
		sock_put(iter->batch[iter->cur_sk++]);

	++*pos;

	if (iter->cur_sk < iter->end_sk)
		sk = iter->batch[iter->cur_sk];
	else
		sk = bpf_iter_unix_batch(seq, pos);

	return sk;
}

static int bpf_iter_unix_seq_show(struct seq_file *seq, void *v)
{
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;
	struct sock *sk = v;
	uid_t uid;
	bool slow;
	int ret;

	if (v == SEQ_START_TOKEN)
		return 0;

	slow = lock_sock_fast(sk);

	if (unlikely(sk_unhashed(sk))) {
		ret = SEQ_SKIP;
		goto unlock;
	}

	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, false);
	return unix_prog_seq_show(prog, &meta, v, uid);
	ret = unix_prog_seq_show(prog, &meta, v, uid);
unlock:
	unlock_sock_fast(sk, slow);
	return ret;
}

static void bpf_iter_unix_seq_stop(struct seq_file *seq, void *v)
{
	struct bpf_unix_iter_state *iter = seq->private;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

@@ -3393,12 +3543,13 @@ static void bpf_iter_unix_seq_stop(struct seq_file *seq, void *v)
			(void)unix_prog_seq_show(prog, &meta, v, 0);
	}

	unix_seq_stop(seq, v);
	if (iter->cur_sk < iter->end_sk)
		bpf_iter_unix_put_batch(iter);
}

static const struct seq_operations bpf_iter_unix_seq_ops = {
	.start	= unix_seq_start,
	.next	= unix_seq_next,
	.start	= bpf_iter_unix_seq_start,
	.next	= bpf_iter_unix_seq_next,
	.stop	= bpf_iter_unix_seq_stop,
	.show	= bpf_iter_unix_seq_show,
};
@@ -3447,13 +3598,55 @@ static struct pernet_operations unix_net_ops = {
DEFINE_BPF_ITER_FUNC(unix, struct bpf_iter_meta *meta,
		     struct unix_sock *unix_sk, uid_t uid)

#define INIT_BATCH_SZ 16

static int bpf_iter_init_unix(void *priv_data, struct bpf_iter_aux_info *aux)
{
	struct bpf_unix_iter_state *iter = priv_data;
	int err;

	err = bpf_iter_init_seq_net(priv_data, aux);
	if (err)
		return err;

	err = bpf_iter_unix_realloc_batch(iter, INIT_BATCH_SZ);
	if (err) {
		bpf_iter_fini_seq_net(priv_data);
		return err;
	}

	return 0;
}

static void bpf_iter_fini_unix(void *priv_data)
{
	struct bpf_unix_iter_state *iter = priv_data;

	bpf_iter_fini_seq_net(priv_data);
	kvfree(iter->batch);
}

static const struct bpf_iter_seq_info unix_seq_info = {
	.seq_ops		= &bpf_iter_unix_seq_ops,
	.init_seq_private	= bpf_iter_init_seq_net,
	.fini_seq_private	= bpf_iter_fini_seq_net,
	.seq_priv_size		= sizeof(struct seq_net_private),
	.init_seq_private	= bpf_iter_init_unix,
	.fini_seq_private	= bpf_iter_fini_unix,
	.seq_priv_size		= sizeof(struct bpf_unix_iter_state),
};

static const struct bpf_func_proto *
bpf_iter_unix_get_func_proto(enum bpf_func_id func_id,
			     const struct bpf_prog *prog)
{
	switch (func_id) {
	case BPF_FUNC_setsockopt:
		return &bpf_sk_setsockopt_proto;
	case BPF_FUNC_getsockopt:
		return &bpf_sk_getsockopt_proto;
	default:
		return NULL;
	}
}

static struct bpf_iter_reg unix_reg_info = {
	.target			= "unix",
	.ctx_arg_info_size	= 1,
@@ -3461,6 +3654,7 @@ static struct bpf_iter_reg unix_reg_info = {
		{ offsetof(struct bpf_iter__unix, unix_sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
	.get_func_proto         = bpf_iter_unix_get_func_proto,
	.seq_info		= &unix_seq_info,
};

tools/testing/selftests/bpf/prog_tests/bpf_iter_setsockopt_unix.c +100 −0
// SPDX-License-Identifier: GPL-2.0
/* Copyright Amazon.com Inc. or its affiliates. */
#include <sys/socket.h>
#include <sys/un.h>
#include <test_progs.h>
#include "bpf_iter_setsockopt_unix.skel.h"

#define NR_CASES 5

static int create_unix_socket(struct bpf_iter_setsockopt_unix *skel)
{
	struct sockaddr_un addr = {
		.sun_family = AF_UNIX,
		.sun_path = "",
	};
	socklen_t len;
	int fd, err;

	fd = socket(AF_UNIX, SOCK_STREAM, 0);
	if (!ASSERT_NEQ(fd, -1, "socket"))
		return -1;

	len = offsetof(struct sockaddr_un, sun_path);
	err = bind(fd, (struct sockaddr *)&addr, len);
	if (!ASSERT_OK(err, "bind"))
		return -1;

	len = sizeof(addr);
	err = getsockname(fd, (struct sockaddr *)&addr, &len);
	if (!ASSERT_OK(err, "getsockname"))
		return -1;

	memcpy(&skel->bss->sun_path, &addr.sun_path,
	       len - offsetof(struct sockaddr_un, sun_path));

	return fd;
}

static void test_sndbuf(struct bpf_iter_setsockopt_unix *skel, int fd)
{
	socklen_t optlen;
	int i, err;

	for (i = 0; i < NR_CASES; i++) {
		if (!ASSERT_NEQ(skel->data->sndbuf_getsockopt[i], -1,
				"bpf_(get|set)sockopt"))
			return;

		err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF,
				 &(skel->data->sndbuf_setsockopt[i]),
				 sizeof(skel->data->sndbuf_setsockopt[i]));
		if (!ASSERT_OK(err, "setsockopt"))
			return;

		optlen = sizeof(skel->bss->sndbuf_getsockopt_expected[i]);
		err = getsockopt(fd, SOL_SOCKET, SO_SNDBUF,
				 &(skel->bss->sndbuf_getsockopt_expected[i]),
				 &optlen);
		if (!ASSERT_OK(err, "getsockopt"))
			return;

		if (!ASSERT_EQ(skel->data->sndbuf_getsockopt[i],
			       skel->bss->sndbuf_getsockopt_expected[i],
			       "bpf_(get|set)sockopt"))
			return;
	}
}

void test_bpf_iter_setsockopt_unix(void)
{
	struct bpf_iter_setsockopt_unix *skel;
	int err, unix_fd, iter_fd;
	char buf;

	skel = bpf_iter_setsockopt_unix__open_and_load();
	if (!ASSERT_OK_PTR(skel, "open_and_load"))
		return;

	unix_fd = create_unix_socket(skel);
	if (!ASSERT_NEQ(unix_fd, -1, "create_unix_server"))
		goto destroy;

	skel->links.change_sndbuf = bpf_program__attach_iter(skel->progs.change_sndbuf, NULL);
	if (!ASSERT_OK_PTR(skel->links.change_sndbuf, "bpf_program__attach_iter"))
		goto destroy;

	iter_fd = bpf_iter_create(bpf_link__fd(skel->links.change_sndbuf));
	if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
		goto destroy;

	while ((err = read(iter_fd, &buf, sizeof(buf))) == -1 &&
	       errno == EAGAIN)
		;
	if (!ASSERT_OK(err, "read iter error"))
		goto destroy;

	test_sndbuf(skel, unix_fd);
destroy:
	bpf_iter_setsockopt_unix__destroy(skel);
}
tools/testing/selftests/bpf/progs/bpf_iter_setsockopt_unix.c +60 −0
// SPDX-License-Identifier: GPL-2.0
/* Copyright Amazon.com Inc. or its affiliates. */
#include "bpf_iter.h"
#include "bpf_tracing_net.h"
#include <bpf/bpf_helpers.h>
#include <limits.h>

#define AUTOBIND_LEN 6
char sun_path[AUTOBIND_LEN];

#define NR_CASES 5
int sndbuf_setsockopt[NR_CASES] = {-1, 0, 8192, INT_MAX / 2, INT_MAX};
int sndbuf_getsockopt[NR_CASES] = {-1, -1, -1, -1, -1};
int sndbuf_getsockopt_expected[NR_CASES];

static inline int cmpname(struct unix_sock *unix_sk)
{
	int i;

	for (i = 0; i < AUTOBIND_LEN; i++) {
		if (unix_sk->addr->name->sun_path[i] != sun_path[i])
			return -1;
	}

	return 0;
}

SEC("iter/unix")
int change_sndbuf(struct bpf_iter__unix *ctx)
{
	struct unix_sock *unix_sk = ctx->unix_sk;
	int i, err;

	if (!unix_sk || !unix_sk->addr)
		return 0;

	if (unix_sk->addr->name->sun_path[0])
		return 0;

	if (cmpname(unix_sk))
		return 0;

	for (i = 0; i < NR_CASES; i++) {
		err = bpf_setsockopt(unix_sk, SOL_SOCKET, SO_SNDBUF,
				     &sndbuf_setsockopt[i],
				     sizeof(sndbuf_setsockopt[i]));
		if (err)
			break;

		err = bpf_getsockopt(unix_sk, SOL_SOCKET, SO_SNDBUF,
				     &sndbuf_getsockopt[i],
				     sizeof(sndbuf_getsockopt[i]));
		if (err)
			break;
	}

	return 0;
}

char _license[] SEC("license") = "GPL";
tools/testing/selftests/bpf/progs/bpf_iter_unix.c +1 −1
@@ -63,7 +63,7 @@ int dump_unix(struct bpf_iter__unix *ctx)
			BPF_SEQ_PRINTF(seq, " @");

			for (i = 1; i < len; i++) {
				/* unix_mkname() tests this upper bound. */
				/* unix_validate_addr() tests this upper bound. */
				if (i >= sizeof(struct sockaddr_un))
					break;

tools/testing/selftests/bpf/progs/bpf_tracing_net.h +2 −0
@@ -5,6 +5,8 @@
#define AF_INET			2
#define AF_INET6		10

#define SOL_SOCKET		1
#define SO_SNDBUF		7
#define __SO_ACCEPTCON		(1 << 16)

#define SOL_TCP			6