Commit 30bab7cd authored by Jiri Pirko's avatar Jiri Pirko Committed by Jakub Kicinski
Browse files

net: devlink: make sure that devlink_try_get() works with valid pointer during xarray iteration



Remove dependency on devlink_mutex during devlinks xarray iteration.

The reason is that devlink_register/unregister() functions taking
devlink_mutex would deadlock during devlink reload operation of devlink
instance which registers/unregisters nested devlink instances.

The devlinks xarray consistency is ensured internally by xarray.
There is a reference taken when working with devlink using
devlink_try_get(). But there is no guarantee that devlink pointer
picked during xarray iteration is not freed before devlink_try_get()
is called.

Make sure that devlink_try_get() works with valid pointer.
Achieve it by:
1) Splitting devlink_put() so the completion is sent only
   after grace period. Completion unblocks the devlink_unregister()
   routine, which is followed-up by devlink_free()
2) During devlinks xa_array iteration, get devlink pointer from xa_array
   holding RCU read lock and taking reference using devlink_try_get()
   before unlock.

Signed-off-by: default avatarJiri Pirko <jiri@nvidia.com>
Reviewed-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent 35d099da
Loading
Loading
Loading
Loading
+80 −91
Original line number Diff line number Diff line
@@ -70,6 +70,7 @@ struct devlink {
	u8 reload_failed:1;
	refcount_t refcount;
	struct completion comp;
	struct rcu_head rcu;
	char priv[] __aligned(NETDEV_ALIGN);
};

@@ -221,8 +222,6 @@ static DEFINE_XARRAY_FLAGS(devlinks, XA_FLAGS_ALLOC);
/* devlink_mutex
 *
 * An overall lock guarding every operation coming from userspace.
 * It also guards devlink devices list and it is taken when
 * driver registers/unregisters it.
 */
static DEFINE_MUTEX(devlink_mutex);

@@ -232,10 +231,21 @@ struct net *devlink_net(const struct devlink *devlink)
}
EXPORT_SYMBOL_GPL(devlink_net);

static void __devlink_put_rcu(struct rcu_head *head)
{
	struct devlink *devlink = container_of(head, struct devlink, rcu);

	complete(&devlink->comp);
}

void devlink_put(struct devlink *devlink)
{
	if (refcount_dec_and_test(&devlink->refcount))
		complete(&devlink->comp);
		/* Make sure unregister operation that may await the completion
		 * is unblocked only after all users are after the end of
		 * RCU grace period.
		 */
		call_rcu(&devlink->rcu, __devlink_put_rcu);
}

struct devlink *__must_check devlink_try_get(struct devlink *devlink)
@@ -278,12 +288,55 @@ void devl_unlock(struct devlink *devlink)
}
EXPORT_SYMBOL_GPL(devl_unlock);

static struct devlink *
devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter,
		     void * (*xa_find_fn)(struct xarray *, unsigned long *,
					  unsigned long, xa_mark_t))
{
	struct devlink *devlink;

	rcu_read_lock();
retry:
	devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED);
	if (!devlink)
		goto unlock;
	/* For a possible retry, the xa_find_after() should be always used */
	xa_find_fn = xa_find_after;
	if (!devlink_try_get(devlink))
		goto retry;
unlock:
	rcu_read_unlock();
	return devlink;
}

static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp,
						  xa_mark_t filter)
{
	return devlinks_xa_find_get(indexp, filter, xa_find);
}

static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp,
						 xa_mark_t filter)
{
	return devlinks_xa_find_get(indexp, filter, xa_find_after);
}

/* Iterate over devlink pointers which were possible to get reference to.
 * devlink_put() needs to be called for each iterated devlink pointer
 * in loop body in order to release the reference.
 */
#define devlinks_xa_for_each_get(index, devlink, filter)			\
	for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter);	\
	     devlink; devlink = devlinks_xa_find_get_next(&index, filter))

#define devlinks_xa_for_each_registered_get(index, devlink)			\
	devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED)

static struct devlink *devlink_get_from_attrs(struct net *net,
					      struct nlattr **attrs)
{
	struct devlink *devlink;
	unsigned long index;
	bool found = false;
	char *busname;
	char *devname;

@@ -293,21 +346,15 @@ static struct devlink *devlink_get_from_attrs(struct net *net,
	busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
	devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);

	lockdep_assert_held(&devlink_mutex);

	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
	devlinks_xa_for_each_registered_get(index, devlink) {
		if (strcmp(devlink->dev->bus->name, busname) == 0 &&
		    strcmp(dev_name(devlink->dev), devname) == 0 &&
		    net_eq(devlink_net(devlink), net)) {
			found = true;
			break;
		}
		    net_eq(devlink_net(devlink), net))
			return devlink;
		devlink_put(devlink);
	}

	if (!found || !devlink_try_get(devlink))
		devlink = ERR_PTR(-ENODEV);

	return devlink;
	return ERR_PTR(-ENODEV);
}

static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink,
@@ -1329,10 +1376,7 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -1432,10 +1476,7 @@ static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
			devlink_put(devlink);
			continue;
@@ -1495,10 +1536,7 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -2177,10 +2215,7 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -2449,10 +2484,7 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -2601,10 +2633,7 @@ static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
		    !devlink->ops->sb_pool_get)
			goto retry;
@@ -2822,10 +2851,7 @@ static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
		    !devlink->ops->sb_port_pool_get)
			goto retry;
@@ -3071,10 +3097,7 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
		    !devlink->ops->sb_tc_pool_bind_get)
			goto retry;
@@ -5158,10 +5181,7 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -5393,10 +5413,7 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -5977,10 +5994,7 @@ static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -6511,10 +6525,7 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg,
	int err = 0;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -7691,10 +7702,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry_rep;

@@ -7721,10 +7729,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
		devlink_put(devlink);
	}

	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry_port;

@@ -8291,10 +8296,7 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -8518,10 +8520,7 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -8832,10 +8831,7 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg,
	int err;

	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
			goto retry;

@@ -9589,10 +9585,8 @@ void devlink_register(struct devlink *devlink)
	ASSERT_DEVLINK_NOT_REGISTERED(devlink);
	/* Make sure that we are in .probe() routine */

	mutex_lock(&devlink_mutex);
	xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
	devlink_notify_register(devlink);
	mutex_unlock(&devlink_mutex);
}
EXPORT_SYMBOL_GPL(devlink_register);

@@ -9609,10 +9603,8 @@ void devlink_unregister(struct devlink *devlink)
	devlink_put(devlink);
	wait_for_completion(&devlink->comp);

	mutex_lock(&devlink_mutex);
	devlink_notify_unregister(devlink);
	xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
	mutex_unlock(&devlink_mutex);
}
EXPORT_SYMBOL_GPL(devlink_unregister);

@@ -12281,10 +12273,7 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net)
	 * all devlink instances from this namespace into init_net.
	 */
	mutex_lock(&devlink_mutex);
	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
		if (!devlink_try_get(devlink))
			continue;

	devlinks_xa_for_each_registered_get(index, devlink) {
		if (!net_eq(devlink_net(devlink), net))
			goto retry;