Commit 084d0c13 authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'net-packet-make-packet_fanout-arr-size-configurable-up-to-64k'

Tanner Love says:

====================
net/packet: make packet_fanout.arr size configurable up to 64K

First patch makes the change; second patch adds unit tests.
====================

Link: https://lore.kernel.org/r/20201106180741.2839668-1-tannerlove.kernel@gmail.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents a3ce2b10 1db32acf
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
#ifndef __LINUX_IF_PACKET_H
#define __LINUX_IF_PACKET_H

#include <asm/byteorder.h>
#include <linux/types.h>

struct sockaddr_pkt {
@@ -296,6 +297,17 @@ struct packet_mreq {
	unsigned char	mr_address[8];
};

struct fanout_args {
#if defined(__LITTLE_ENDIAN_BITFIELD)
	__u16		id;
	__u16		type_flags;
#else
	__u16		type_flags;
	__u16		id;
#endif
	__u32		max_num_members;
};

#define PACKET_MR_MULTICAST	0
#define PACKET_MR_PROMISC	1
#define PACKET_MR_ALLMULTI	2
+25 −12
Original line number Diff line number Diff line
@@ -1636,13 +1636,15 @@ static bool fanout_find_new_id(struct sock *sk, u16 *new_id)
	return false;
}

static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
static int fanout_add(struct sock *sk, struct fanout_args *args)
{
	struct packet_rollover *rollover = NULL;
	struct packet_sock *po = pkt_sk(sk);
	u16 type_flags = args->type_flags;
	struct packet_fanout *f, *match;
	u8 type = type_flags & 0xff;
	u8 flags = type_flags >> 8;
	u16 id = args->id;
	int err;

	switch (type) {
@@ -1700,11 +1702,21 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
		}
	}
	err = -EINVAL;
	if (match && match->flags != flags)
	if (match) {
		if (match->flags != flags)
			goto out;
	if (!match) {
		if (args->max_num_members &&
		    args->max_num_members != match->max_num_members)
			goto out;
	} else {
		if (args->max_num_members > PACKET_FANOUT_MAX)
			goto out;
		if (!args->max_num_members)
			/* legacy PACKET_FANOUT_MAX */
			args->max_num_members = 256;
		err = -ENOMEM;
		match = kzalloc(sizeof(*match), GFP_KERNEL);
		match = kvzalloc(struct_size(match, arr, args->max_num_members),
				 GFP_KERNEL);
		if (!match)
			goto out;
		write_pnet(&match->net, sock_net(sk));
@@ -1720,6 +1732,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
		match->prot_hook.func = packet_rcv_fanout;
		match->prot_hook.af_packet_priv = match;
		match->prot_hook.id_match = match_fanout_group;
		match->max_num_members = args->max_num_members;
		list_add(&match->list, &fanout_list);
	}
	err = -EINVAL;
@@ -1730,7 +1743,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
	    match->prot_hook.type == po->prot_hook.type &&
	    match->prot_hook.dev == po->prot_hook.dev) {
		err = -ENOSPC;
		if (refcount_read(&match->sk_ref) < PACKET_FANOUT_MAX) {
		if (refcount_read(&match->sk_ref) < match->max_num_members) {
			__dev_remove_pack(&po->prot_hook);
			po->fanout = match;
			po->rollover = rollover;
@@ -1744,7 +1757,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)

	if (err && !refcount_read(&match->sk_ref)) {
		list_del(&match->list);
		kfree(match);
		kvfree(match);
	}

out:
@@ -3075,7 +3088,7 @@ static int packet_release(struct socket *sock)
	kfree(po->rollover);
	if (f) {
		fanout_release_data(f);
		kfree(f);
		kvfree(f);
	}
	/*
	 *	Now the socket is dead. No more input will appear.
@@ -3866,14 +3879,14 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
	}
	case PACKET_FANOUT:
	{
		int val;
		struct fanout_args args = { 0 };

		if (optlen != sizeof(val))
		if (optlen != sizeof(int) && optlen != sizeof(args))
			return -EINVAL;
		if (copy_from_sockptr(&val, optval, sizeof(val)))
		if (copy_from_sockptr(&args, optval, optlen))
			return -EFAULT;

		return fanout_add(sk, val & 0xffff, val >> 16);
		return fanout_add(sk, &args);
	}
	case PACKET_FANOUT_DATA:
	{
+3 −2
Original line number Diff line number Diff line
@@ -77,11 +77,12 @@ struct packet_ring_buffer {
};

extern struct mutex fanout_mutex;
#define PACKET_FANOUT_MAX	256
#define PACKET_FANOUT_MAX	(1 << 16)

struct packet_fanout {
	possible_net_t		net;
	unsigned int		num_members;
	u32			max_num_members;
	u16			id;
	u8			type;
	u8			flags;
@@ -90,10 +91,10 @@ struct packet_fanout {
		struct bpf_prog __rcu	*bpf_prog;
	};
	struct list_head	list;
	struct sock		*arr[PACKET_FANOUT_MAX];
	spinlock_t		lock;
	refcount_t		sk_ref;
	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
	struct sock		*arr[];
};

struct packet_rollover {
+69 −3
Original line number Diff line number Diff line
@@ -56,12 +56,15 @@

#define RING_NUM_FRAMES			20

static uint32_t cfg_max_num_members;

/* Open a socket in a given fanout mode.
 * @return -1 if mode is bad, a valid socket otherwise */
static int sock_fanout_open(uint16_t typeflags, uint16_t group_id)
{
	struct sockaddr_ll addr = {0};
	int fd, val;
	struct fanout_args args;
	int fd, val, err;

	fd = socket(PF_PACKET, SOCK_RAW, 0);
	if (fd < 0) {
@@ -83,8 +86,18 @@ static int sock_fanout_open(uint16_t typeflags, uint16_t group_id)
		exit(1);
	}

	if (cfg_max_num_members) {
		args.id = group_id;
		args.type_flags = typeflags;
		args.max_num_members = cfg_max_num_members;
		err = setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &args,
				 sizeof(args));
	} else {
		val = (((int) typeflags) << 16) | group_id;
	if (setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &val, sizeof(val))) {
		err = setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &val,
				 sizeof(val));
	}
	if (err) {
		if (close(fd)) {
			perror("close packet");
			exit(1);
@@ -286,6 +299,56 @@ static void test_control_group(void)
	}
}

/* Test illegal max_num_members values */
static void test_control_group_max_num_members(void)
{
	int fds[3];

	fprintf(stderr, "test: control multiple sockets, max_num_members\n");

	/* expected failure on greater than PACKET_FANOUT_MAX */
	cfg_max_num_members = (1 << 16) + 1;
	if (sock_fanout_open(PACKET_FANOUT_HASH, 0) != -1) {
		fprintf(stderr, "ERROR: max_num_members > PACKET_FANOUT_MAX\n");
		exit(1);
	}

	cfg_max_num_members = 256;
	fds[0] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (fds[0] == -1) {
		fprintf(stderr, "ERROR: failed open\n");
		exit(1);
	}

	/* expected failure on joining group with different max_num_members */
	cfg_max_num_members = 257;
	if (sock_fanout_open(PACKET_FANOUT_HASH, 0) != -1) {
		fprintf(stderr, "ERROR: set different max_num_members\n");
		exit(1);
	}

	/* success on joining group with same max_num_members */
	cfg_max_num_members = 256;
	fds[1] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (fds[1] == -1) {
		fprintf(stderr, "ERROR: failed to join group\n");
		exit(1);
	}

	/* success on joining group with max_num_members unspecified */
	cfg_max_num_members = 0;
	fds[2] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (fds[2] == -1) {
		fprintf(stderr, "ERROR: failed to join group\n");
		exit(1);
	}

	if (close(fds[2]) || close(fds[1]) || close(fds[0])) {
		fprintf(stderr, "ERROR: closing sockets\n");
		exit(1);
	}
}

/* Test creating a unique fanout group ids */
static void test_unique_fanout_group_ids(void)
{
@@ -426,8 +489,11 @@ int main(int argc, char **argv)

	test_control_single();
	test_control_group();
	test_control_group_max_num_members();
	test_unique_fanout_group_ids();

	/* PACKET_FANOUT_MAX */
	cfg_max_num_members = 1 << 16;
	/* find a set of ports that do not collide onto the same socket */
	ret = test_datapath(PACKET_FANOUT_HASH, port_off,
			    expect_hash[0], expect_hash[1]);