Commit 7c678829 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'mptcp-Include-multiple-address-ids-in-RM_ADDR'



Mat Martineau says:

====================
mptcp: Include multiple address ids in RM_ADDR

Here's a patch series from the MPTCP tree that extends the capabilities
of the MPTCP RM_ADDR header.

MPTCP peers can exchange information about their IP addresses that are
available for additional MPTCP subflows. IP addresses are advertised
with an ADD_ADDR header type, and those advertisements are revoked with
the RM_ADDR header type. RFC 8684 allows the RM_ADDR header to include
more than one address ID, so multiple advertisements can be revoked in a
single header. Previous kernel versions have only used RM_ADDR with a
single address ID, so multiple removals required multiple packets.

Patches 1-4 plumb address id list structures around the MPTCP code,
where before only a single address ID was passed.

Patches 5-8 make use of the address lists at the path manager layer that
tracks available addresses for both peers.

Patches 9-11 update the selftests to cover the new use of RM_ADDR with
multiple address IDs.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents e9e90a70 d2c4333a
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -34,6 +34,13 @@ struct mptcp_ext {
	/* one byte hole */
};

#define MPTCP_RM_IDS_MAX	8

struct mptcp_rm_list {
	u8 ids[MPTCP_RM_IDS_MAX];
	u8 nr;
};

struct mptcp_out_options {
#if IS_ENABLED(CONFIG_MPTCP)
	u16 suboptions;
@@ -48,7 +55,7 @@ struct mptcp_out_options {
	u8 addr_id;
	u16 port;
	u64 ahmac;
	u8 rm_id;
	struct mptcp_rm_list rm_list;
	u8 join_id;
	u8 backup;
	u32 nonce;
+35 −12
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ static void mptcp_parse_option(const struct sk_buff *skb,
	int expected_opsize;
	u8 version;
	u8 flags;
	u8 i;

	switch (subtype) {
	case MPTCPOPT_MP_CAPABLE:
@@ -272,14 +273,17 @@ static void mptcp_parse_option(const struct sk_buff *skb,
		break;

	case MPTCPOPT_RM_ADDR:
		if (opsize != TCPOLEN_MPTCP_RM_ADDR_BASE)
		if (opsize < TCPOLEN_MPTCP_RM_ADDR_BASE + 1 ||
		    opsize > TCPOLEN_MPTCP_RM_ADDR_BASE + MPTCP_RM_IDS_MAX)
			break;

		ptr++;

		mp_opt->rm_addr = 1;
		mp_opt->rm_id = *ptr++;
		pr_debug("RM_ADDR: id=%d", mp_opt->rm_id);
		mp_opt->rm_list.nr = opsize - TCPOLEN_MPTCP_RM_ADDR_BASE;
		for (i = 0; i < mp_opt->rm_list.nr; i++)
			mp_opt->rm_list.ids[i] = *ptr++;
		pr_debug("RM_ADDR: rm_list_nr=%d", mp_opt->rm_list.nr);
		break;

	case MPTCPOPT_MP_PRIO:
@@ -674,20 +678,25 @@ static bool mptcp_established_options_rm_addr(struct sock *sk,
{
	struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
	struct mptcp_sock *msk = mptcp_sk(subflow->conn);
	u8 rm_id;
	struct mptcp_rm_list rm_list;
	int i, len;

	if (!mptcp_pm_should_rm_signal(msk) ||
	    !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_id)))
	    !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_list)))
		return false;

	if (remaining < TCPOLEN_MPTCP_RM_ADDR_BASE)
	len = mptcp_rm_addr_len(&rm_list);
	if (len < 0)
		return false;
	if (remaining < len)
		return false;

	*size = TCPOLEN_MPTCP_RM_ADDR_BASE;
	*size = len;
	opts->suboptions |= OPTION_MPTCP_RM_ADDR;
	opts->rm_id = rm_id;
	opts->rm_list = rm_list;

	pr_debug("rm_id=%d", opts->rm_id);
	for (i = 0; i < opts->rm_list.nr; i++)
		pr_debug("rm_list_ids[%d]=%d", i, opts->rm_list.ids[i]);

	return true;
}
@@ -1038,7 +1047,7 @@ void mptcp_incoming_options(struct sock *sk, struct sk_buff *skb)
	}

	if (mp_opt.rm_addr) {
		mptcp_pm_rm_addr_received(msk, mp_opt.rm_id);
		mptcp_pm_rm_addr_received(msk, &mp_opt.rm_list);
		mp_opt.rm_addr = 0;
	}

@@ -1217,9 +1226,23 @@ void mptcp_write_options(__be32 *ptr, const struct tcp_sock *tp,
	}

	if (OPTION_MPTCP_RM_ADDR & opts->suboptions) {
		u8 i = 1;

		*ptr++ = mptcp_option(MPTCPOPT_RM_ADDR,
				      TCPOLEN_MPTCP_RM_ADDR_BASE,
				      0, opts->rm_id);
				      TCPOLEN_MPTCP_RM_ADDR_BASE + opts->rm_list.nr,
				      0, opts->rm_list.ids[0]);

		while (i < opts->rm_list.nr) {
			u8 id1, id2, id3, id4;

			id1 = opts->rm_list.ids[i];
			id2 = i + 1 < opts->rm_list.nr ? opts->rm_list.ids[i + 1] : TCPOPT_NOP;
			id3 = i + 2 < opts->rm_list.nr ? opts->rm_list.ids[i + 2] : TCPOPT_NOP;
			id4 = i + 3 < opts->rm_list.nr ? opts->rm_list.ids[i + 3] : TCPOPT_NOP;
			put_unaligned_be32(id1 << 24 | id2 << 16 | id3 << 8 | id4, ptr);
			ptr += 1;
			i += 4;
		}
	}

	if (OPTION_MPTCP_PRIO & opts->suboptions) {
+24 −15
Original line number Diff line number Diff line
@@ -39,29 +39,29 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk,
	return 0;
}

int mptcp_pm_remove_addr(struct mptcp_sock *msk, u8 local_id)
int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list)
{
	u8 rm_addr = READ_ONCE(msk->pm.addr_signal);

	pr_debug("msk=%p, local_id=%d", msk, local_id);
	pr_debug("msk=%p, rm_list_nr=%d", msk, rm_list->nr);

	if (rm_addr) {
		pr_warn("addr_signal error, rm_addr=%d", rm_addr);
		return -EINVAL;
	}

	msk->pm.rm_id = local_id;
	msk->pm.rm_list_tx = *rm_list;
	rm_addr |= BIT(MPTCP_RM_ADDR_SIGNAL);
	WRITE_ONCE(msk->pm.addr_signal, rm_addr);
	return 0;
}

int mptcp_pm_remove_subflow(struct mptcp_sock *msk, u8 local_id)
int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list)
{
	pr_debug("msk=%p, local_id=%d", msk, local_id);
	pr_debug("msk=%p, rm_list_nr=%d", msk, rm_list->nr);

	spin_lock_bh(&msk->pm.lock);
	mptcp_pm_nl_rm_subflow_received(msk, local_id);
	mptcp_pm_nl_rm_subflow_received(msk, rm_list);
	spin_unlock_bh(&msk->pm.lock);
	return 0;
}
@@ -205,17 +205,20 @@ void mptcp_pm_add_addr_send_ack(struct mptcp_sock *msk)
	mptcp_pm_schedule_work(msk, MPTCP_PM_ADD_ADDR_SEND_ACK);
}

void mptcp_pm_rm_addr_received(struct mptcp_sock *msk, u8 rm_id)
void mptcp_pm_rm_addr_received(struct mptcp_sock *msk,
			       const struct mptcp_rm_list *rm_list)
{
	struct mptcp_pm_data *pm = &msk->pm;
	u8 i;

	pr_debug("msk=%p remote_id=%d", msk, rm_id);
	pr_debug("msk=%p remote_ids_nr=%d", msk, rm_list->nr);

	mptcp_event_addr_removed(msk, rm_id);
	for (i = 0; i < rm_list->nr; i++)
		mptcp_event_addr_removed(msk, rm_list->ids[i]);

	spin_lock_bh(&pm->lock);
	mptcp_pm_schedule_work(msk, MPTCP_PM_RM_ADDR_RECEIVED);
	pm->rm_id = rm_id;
	pm->rm_list_rx = *rm_list;
	spin_unlock_bh(&pm->lock);
}

@@ -258,9 +261,9 @@ bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
}

bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
			     u8 *rm_id)
			     struct mptcp_rm_list *rm_list)
{
	int ret = false;
	int ret = false, len;

	spin_lock_bh(&msk->pm.lock);

@@ -268,10 +271,15 @@ bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
	if (!mptcp_pm_should_rm_signal(msk))
		goto out_unlock;

	if (remaining < TCPOLEN_MPTCP_RM_ADDR_BASE)
	len = mptcp_rm_addr_len(&msk->pm.rm_list_tx);
	if (len < 0) {
		WRITE_ONCE(msk->pm.addr_signal, 0);
		goto out_unlock;
	}
	if (remaining < len)
		goto out_unlock;

	*rm_id = msk->pm.rm_id;
	*rm_list = msk->pm.rm_list_tx;
	WRITE_ONCE(msk->pm.addr_signal, 0);
	ret = true;

@@ -291,7 +299,8 @@ void mptcp_pm_data_init(struct mptcp_sock *msk)
	msk->pm.add_addr_accepted = 0;
	msk->pm.local_addr_used = 0;
	msk->pm.subflows = 0;
	msk->pm.rm_id = 0;
	msk->pm.rm_list_tx.nr = 0;
	msk->pm.rm_list_rx.nr = 0;
	WRITE_ONCE(msk->pm.work_pending, false);
	WRITE_ONCE(msk->pm.addr_signal, 0);
	WRITE_ONCE(msk->pm.accept_addr, false);
+101 −38
Original line number Diff line number Diff line
@@ -575,24 +575,27 @@ static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;
	u8 i;

	pr_debug("address rm_id %d", msk->pm.rm_id);
	pr_debug("address rm_list_nr %d", msk->pm.rm_list_rx.nr);

	msk_owned_by_me(msk);

	if (!msk->pm.rm_id)
	if (!msk->pm.rm_list_rx.nr)
		return;

	if (list_empty(&msk->conn_list))
		return;

	for (i = 0; i < msk->pm.rm_list_rx.nr; i++) {
		list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
			int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (msk->pm.rm_id != subflow->remote_id)
			if (msk->pm.rm_list_rx.ids[i] != subflow->remote_id)
				continue;

			pr_debug(" -> address rm_list_ids[%d]=%u", i, msk->pm.rm_list_rx.ids[i]);
			spin_unlock_bh(&msk->pm.lock);
			mptcp_subflow_shutdown(sk, ssk, how);
			mptcp_close_ssk(sk, ssk, subflow);
@@ -607,6 +610,7 @@ static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
			break;
		}
	}
}

void mptcp_pm_nl_work(struct mptcp_sock *msk)
{
@@ -641,28 +645,32 @@ void mptcp_pm_nl_work(struct mptcp_sock *msk)
	spin_unlock_bh(&msk->pm.lock);
}

void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
				     const struct mptcp_rm_list *rm_list)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;
	u8 i;

	pr_debug("subflow rm_id %d", rm_id);
	pr_debug("subflow rm_list_nr %d", rm_list->nr);

	msk_owned_by_me(msk);

	if (!rm_id)
	if (!rm_list->nr)
		return;

	if (list_empty(&msk->conn_list))
		return;

	for (i = 0; i < rm_list->nr; i++) {
		list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
			int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (rm_id != subflow->local_id)
			if (rm_list->ids[i] != subflow->local_id)
				continue;

			pr_debug(" -> subflow rm_list_ids[%d]=%u", i, rm_list->ids[i]);
			spin_unlock_bh(&msk->pm.lock);
			mptcp_subflow_shutdown(sk, ssk, how);
			mptcp_close_ssk(sk, ssk, subflow);
@@ -676,6 +684,7 @@ void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
			break;
		}
	}
}

static bool address_use_port(struct mptcp_pm_addr_entry *entry)
{
@@ -1071,12 +1080,15 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr,
				      bool force)
{
	struct mptcp_rm_list list = { .nr = 0 };
	bool ret;

	list.ids[list.nr++] = addr->id;

	ret = remove_anno_list_by_saddr(msk, addr);
	if (ret || force) {
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, addr->id);
		mptcp_pm_remove_addr(msk, &list);
		spin_unlock_bh(&msk->pm.lock);
	}
	return ret;
@@ -1087,9 +1099,12 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;
	struct mptcp_rm_list list = { .nr = 0 };

	pr_debug("remove_id=%d", addr->id);

	list.ids[list.nr++] = addr->id;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		bool remove_subflow;
@@ -1103,7 +1118,7 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
		remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
		mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
		if (remove_subflow)
			mptcp_pm_remove_subflow(msk, addr->id);
			mptcp_pm_remove_subflow(msk, &list);
		release_sock(sk);

next:
@@ -1185,14 +1200,61 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
	return ret;
}

static void __flush_addrs(struct net *net, struct list_head *list)
static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
					       struct list_head *rm_list)
{
	struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };
	struct mptcp_pm_addr_entry *entry;

	list_for_each_entry(entry, rm_list, list) {
		if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
		    alist.nr < MPTCP_RM_IDS_MAX &&
		    slist.nr < MPTCP_RM_IDS_MAX) {
			alist.ids[alist.nr++] = entry->addr.id;
			slist.ids[slist.nr++] = entry->addr.id;
		} else if (remove_anno_list_by_saddr(msk, &entry->addr) &&
			 alist.nr < MPTCP_RM_IDS_MAX) {
			alist.ids[alist.nr++] = entry->addr.id;
		}
	}

	if (alist.nr) {
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, &alist);
		spin_unlock_bh(&msk->pm.lock);
	}
	if (slist.nr)
		mptcp_pm_remove_subflow(msk, &slist);
}

static void mptcp_nl_remove_addrs_list(struct net *net,
				       struct list_head *rm_list)
{
	long s_slot = 0, s_num = 0;
	struct mptcp_sock *msk;

	if (list_empty(rm_list))
		return;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		lock_sock(sk);
		mptcp_pm_remove_addrs_and_subflows(msk, rm_list);
		release_sock(sk);

		sock_put(sk);
		cond_resched();
	}
}

static void __flush_addrs(struct list_head *list)
{
	while (!list_empty(list)) {
		struct mptcp_pm_addr_entry *cur;

		cur = list_entry(list->next,
				 struct mptcp_pm_addr_entry, list);
		mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr);
		list_del_rcu(&cur->list);
		mptcp_pm_free_addr_entry(cur);
	}
@@ -1217,7 +1279,8 @@ static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
	pernet->next_id = 1;
	bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
	spin_unlock_bh(&pernet->lock);
	__flush_addrs(sock_net(skb->sk), &free_list);
	mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list);
	__flush_addrs(&free_list);
	return 0;
}

@@ -1814,7 +1877,7 @@ static void __net_exit pm_nl_exit_net(struct list_head *net_list)
		/* net is removed from namespace list, can't race with
		 * other modifiers
		 */
		__flush_addrs(net, &pernet->local_addr_list);
		__flush_addrs(&pernet->local_addr_list);
	}
}

+19 −8
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@
#define TCPOLEN_MPTCP_ADD_ADDR6_BASE_PORT	22
#define TCPOLEN_MPTCP_PORT_LEN		2
#define TCPOLEN_MPTCP_PORT_ALIGN	2
#define TCPOLEN_MPTCP_RM_ADDR_BASE	4
#define TCPOLEN_MPTCP_RM_ADDR_BASE	3
#define TCPOLEN_MPTCP_PRIO		3
#define TCPOLEN_MPTCP_PRIO_ALIGN	4
#define TCPOLEN_MPTCP_FASTCLOSE		12
@@ -142,7 +142,7 @@ struct mptcp_options_received {
		mpc_map:1,
		__unused:2;
	u8	addr_id;
	u8	rm_id;
	struct mptcp_rm_list rm_list;
	union {
		struct in_addr	addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
@@ -207,7 +207,8 @@ struct mptcp_pm_data {
	u8		local_addr_used;
	u8		subflows;
	u8		status;
	u8		rm_id;
	struct mptcp_rm_list rm_list_tx;
	struct mptcp_rm_list rm_list_rx;
};

struct mptcp_data_frag {
@@ -647,7 +648,8 @@ void mptcp_pm_subflow_closed(struct mptcp_sock *msk, u8 id);
void mptcp_pm_add_addr_received(struct mptcp_sock *msk,
				const struct mptcp_addr_info *addr);
void mptcp_pm_add_addr_send_ack(struct mptcp_sock *msk);
void mptcp_pm_rm_addr_received(struct mptcp_sock *msk, u8 rm_id);
void mptcp_pm_rm_addr_received(struct mptcp_sock *msk,
			       const struct mptcp_rm_list *rm_list);
void mptcp_pm_mp_prio_received(struct sock *sk, u8 bkup);
int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
				 struct mptcp_addr_info *addr,
@@ -661,8 +663,8 @@ mptcp_pm_del_add_timer(struct mptcp_sock *msk,
int mptcp_pm_announce_addr(struct mptcp_sock *msk,
			   const struct mptcp_addr_info *addr,
			   bool echo, bool port);
int mptcp_pm_remove_addr(struct mptcp_sock *msk, u8 local_id);
int mptcp_pm_remove_subflow(struct mptcp_sock *msk, u8 local_id);
int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);
int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);

void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
		 const struct sock *ssk, gfp_t gfp);
@@ -709,16 +711,25 @@ static inline unsigned int mptcp_add_addr_len(int family, bool echo, bool port)
	return len;
}

static inline int mptcp_rm_addr_len(const struct mptcp_rm_list *rm_list)
{
	if (rm_list->nr == 0 || rm_list->nr > MPTCP_RM_IDS_MAX)
		return -EINVAL;

	return TCPOLEN_MPTCP_RM_ADDR_BASE + roundup(rm_list->nr - 1, 4) + 1;
}

bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
			      struct mptcp_addr_info *saddr, bool *echo, bool *port);
bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
			     u8 *rm_id);
			     struct mptcp_rm_list *rm_list);
int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc);

void __init mptcp_pm_nl_init(void);
void mptcp_pm_nl_data_init(struct mptcp_sock *msk);
void mptcp_pm_nl_work(struct mptcp_sock *msk);
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id);
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
				     const struct mptcp_rm_list *rm_list);
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc);
unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk);
unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk);
Loading