Commit d4d02d8b authored by David Howells's avatar David Howells
Browse files

rxrpc: Clone received jumbo subpackets and queue separately



Split up received jumbo packets into separate skbuffs by cloning the
original skbuff for each subpacket and setting the offset and length of the
data in that subpacket in the skbuff's private data.  The subpackets are
then placed on the recvmsg queue separately.  The security class then gets
to revise the offset and length to remove its metadata.

If we fail to clone a packet, we just drop it and let the peer resend it.
The original packet gets used for the final subpacket.

This should make it easier to handle parallel decryption of the subpackets.
It also simplifies the handling of lost or misordered packets in the
queuing/buffering loop as the possibility of overlapping jumbo packets no
longer needs to be considered.

Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: linux-afs@lists.infradead.org
parent faf92e8d
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
 */
#define rxrpc_skb_traces \
	EM(rxrpc_skb_cleaned,			"CLN") \
	EM(rxrpc_skb_cloned_jumbo,		"CLJ") \
	EM(rxrpc_skb_freed,			"FRE") \
	EM(rxrpc_skb_got,			"GOT") \
	EM(rxrpc_skb_lost,			"*L*") \
@@ -630,16 +631,15 @@ TRACE_EVENT(rxrpc_transmit,

TRACE_EVENT(rxrpc_rx_data,
	    TP_PROTO(unsigned int call, rxrpc_seq_t seq,
		     rxrpc_serial_t serial, u8 flags, u8 anno),
		     rxrpc_serial_t serial, u8 flags),

	    TP_ARGS(call, seq, serial, flags, anno),
	    TP_ARGS(call, seq, serial, flags),

	    TP_STRUCT__entry(
		    __field(unsigned int,		call		)
		    __field(rxrpc_seq_t,		seq		)
		    __field(rxrpc_serial_t,		serial		)
		    __field(u8,				flags		)
		    __field(u8,				anno		)
			     ),

	    TP_fast_assign(
@@ -647,15 +647,13 @@ TRACE_EVENT(rxrpc_rx_data,
		    __entry->seq = seq;
		    __entry->serial = serial;
		    __entry->flags = flags;
		    __entry->anno = anno;
			   ),

	    TP_printk("c=%08x DATA %08x q=%08x fl=%02x a=%02x",
	    TP_printk("c=%08x DATA %08x q=%08x fl=%02x",
		      __entry->call,
		      __entry->serial,
		      __entry->seq,
		      __entry->flags,
		      __entry->anno)
		      __entry->flags)
	    );

TRACE_EVENT(rxrpc_rx_ack,
+9 −23
Original line number Diff line number Diff line
@@ -195,17 +195,12 @@ struct rxrpc_host_header {
 * - max 48 bytes (struct sk_buff::cb)
 */
struct rxrpc_skb_priv {
	atomic_t	nr_ring_pins;		/* Number of rxtx ring pins */
	u8		nr_subpackets;		/* Number of subpackets */
	u16		remain;
	u16		offset;		/* Offset of data */
	u16		len;		/* Length of data */
	u8		rx_flags;	/* Received packet flags */
#define RXRPC_SKB_INCL_LAST	0x01		/* - Includes last packet */
	union {
		int		remain;		/* amount of space remaining for next write */

		/* List of requested ACKs on subpackets */
		unsigned long	rx_req_ack[(RXRPC_MAX_NR_JUMBO + BITS_PER_LONG - 1) /
					   BITS_PER_LONG];
	};
	u8		flags;
#define RXRPC_RX_VERIFIED	0x01

	struct rxrpc_host_header hdr;	/* RxRPC packet header from this packet */
};
@@ -252,16 +247,11 @@ struct rxrpc_security {
	int (*secure_packet)(struct rxrpc_call *, struct sk_buff *, size_t);

	/* verify the security on a received packet */
	int (*verify_packet)(struct rxrpc_call *, struct sk_buff *,
			     unsigned int, unsigned int, rxrpc_seq_t, u16);
	int (*verify_packet)(struct rxrpc_call *, struct sk_buff *);

	/* Free crypto request on a call */
	void (*free_call_crypto)(struct rxrpc_call *);

	/* Locate the data in a received packet that has been verified. */
	void (*locate_data)(struct rxrpc_call *, struct sk_buff *,
			    unsigned int *, unsigned int *);

	/* issue a challenge */
	int (*issue_challenge)(struct rxrpc_connection *);

@@ -628,7 +618,6 @@ struct rxrpc_call {
	int			debug_id;	/* debug ID for printks */
	unsigned short		rx_pkt_offset;	/* Current recvmsg packet offset */
	unsigned short		rx_pkt_len;	/* Current recvmsg packet len */
	bool			rx_pkt_last;	/* Current recvmsg packet is last */

	/* Rx/Tx circular buffer, depending on phase.
	 *
@@ -652,8 +641,6 @@ struct rxrpc_call {
#define RXRPC_TX_ANNO_LAST	0x04
#define RXRPC_TX_ANNO_RESENT	0x08

#define RXRPC_RX_ANNO_SUBPACKET	0x3f		/* Subpacket number in jumbogram */
#define RXRPC_RX_ANNO_VERIFIED	0x80		/* Set if verified and decrypted */
	rxrpc_seq_t		tx_hard_ack;	/* Dead slot in buffer; the first transmitted but
						 * not hard-ACK'd packet follows this.
						 */
@@ -681,7 +668,6 @@ struct rxrpc_call {
	rxrpc_serial_t		rx_serial;	/* Highest serial received for this call */
	u8			rx_winsize;	/* Size of Rx window */
	u8			tx_winsize;	/* Maximum size of Tx window */
	u8			nr_jumbo_bad;	/* Number of jumbo dups/exceeds-windows */

	spinlock_t		input_lock;	/* Lock for packet input to this call */

+185 −216
Original line number Diff line number Diff line
@@ -312,118 +312,18 @@ static bool rxrpc_receiving_reply(struct rxrpc_call *call)
	return rxrpc_end_tx_phase(call, true, "ETD");
}

/*
 * Scan a data packet to validate its structure and to work out how many
 * subpackets it contains.
 *
 * A jumbo packet is a collection of consecutive packets glued together with
 * little headers between that indicate how to change the initial header for
 * each subpacket.
 *
 * RXRPC_JUMBO_PACKET must be set on all but the last subpacket - and all but
 * the last are RXRPC_JUMBO_DATALEN in size.  The last subpacket may be of any
 * size.
 */
static bool rxrpc_validate_data(struct sk_buff *skb)
{
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
	unsigned int offset = sizeof(struct rxrpc_wire_header);
	unsigned int len = skb->len;
	u8 flags = sp->hdr.flags;

	for (;;) {
		if (flags & RXRPC_REQUEST_ACK)
			__set_bit(sp->nr_subpackets, sp->rx_req_ack);
		sp->nr_subpackets++;

		if (!(flags & RXRPC_JUMBO_PACKET))
			break;

		if (len - offset < RXRPC_JUMBO_SUBPKTLEN)
			goto protocol_error;
		if (flags & RXRPC_LAST_PACKET)
			goto protocol_error;
		offset += RXRPC_JUMBO_DATALEN;
		if (skb_copy_bits(skb, offset, &flags, 1) < 0)
			goto protocol_error;
		offset += sizeof(struct rxrpc_jumbo_header);
	}

	if (flags & RXRPC_LAST_PACKET)
		sp->rx_flags |= RXRPC_SKB_INCL_LAST;
	return true;

protocol_error:
	return false;
}

/*
 * Handle reception of a duplicate packet.
 *
 * We have to take care to avoid an attack here whereby we're given a series of
 * jumbograms, each with a sequence number one before the preceding one and
 * filled up to maximum UDP size.  If they never send us the first packet in
 * the sequence, they can cause us to have to hold on to around 2MiB of kernel
 * space until the call times out.
 *
 * We limit the space usage by only accepting three duplicate jumbo packets per
 * call.  After that, we tell the other side we're no longer accepting jumbos
 * (that information is encoded in the ACK packet).
 */
static void rxrpc_input_dup_data(struct rxrpc_call *call, rxrpc_seq_t seq,
				 bool is_jumbo, bool *_jumbo_bad)
{
	/* Discard normal packets that are duplicates. */
	if (is_jumbo)
		return;

	/* Skip jumbo subpackets that are duplicates.  When we've had three or
	 * more partially duplicate jumbo packets, we refuse to take any more
	 * jumbos for this call.
	 */
	if (!*_jumbo_bad) {
		call->nr_jumbo_bad++;
		*_jumbo_bad = true;
	}
}

/*
 * Process a DATA packet, adding the packet to the Rx ring.  The caller's
 * packet ref must be passed on or discarded.
 */
static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
{
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
	enum rxrpc_call_state state;
	unsigned int j, nr_subpackets, nr_unacked = 0;
	rxrpc_serial_t serial = sp->hdr.serial, ack_serial = serial;
	rxrpc_seq_t seq0 = sp->hdr.seq, hard_ack;
	bool jumbo_bad = false;

	_enter("{%u,%u},{%u,%u}",
	       call->rx_hard_ack, call->rx_top, skb->len, seq0);

	_proto("Rx DATA %%%u { #%u f=%02x n=%u }",
	       sp->hdr.serial, seq0, sp->hdr.flags, sp->nr_subpackets);

	state = READ_ONCE(call->state);
	if (state >= RXRPC_CALL_COMPLETE) {
		rxrpc_free_skb(skb, rxrpc_skb_freed);
		return;
	}

	if (state == RXRPC_CALL_SERVER_RECV_REQUEST) {
		unsigned long timo = READ_ONCE(call->next_req_timo);
		unsigned long now, expect_req_by;

		if (timo) {
			now = jiffies;
			expect_req_by = now + timo;
			WRITE_ONCE(call->expect_req_by, expect_req_by);
			rxrpc_reduce_call_timer(call, expect_req_by, now,
						rxrpc_timer_set_for_idle);
		}
	}
	rxrpc_serial_t serial = sp->hdr.serial;
	rxrpc_seq_t seq = sp->hdr.seq, hard_ack;
	unsigned int ix = seq & RXRPC_RXTX_BUFF_MASK;
	bool last = sp->hdr.flags & RXRPC_LAST_PACKET;
	bool acked = false;

	rxrpc_inc_stat(call->rxnet, stat_rx_data);
	if (sp->hdr.flags & RXRPC_REQUEST_ACK)
@@ -431,97 +331,52 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
	if (sp->hdr.flags & RXRPC_JUMBO_PACKET)
		rxrpc_inc_stat(call->rxnet, stat_rx_data_jumbo);

	spin_lock(&call->input_lock);

	/* Received data implicitly ACKs all of the request packets we sent
	 * when we're acting as a client.
	 */
	if ((state == RXRPC_CALL_CLIENT_SEND_REQUEST ||
	     state == RXRPC_CALL_CLIENT_AWAIT_REPLY) &&
	    !rxrpc_receiving_reply(call))
		goto unlock;

	hard_ack = READ_ONCE(call->rx_hard_ack);

	nr_subpackets = sp->nr_subpackets;
	if (nr_subpackets > 1) {
		if (call->nr_jumbo_bad > 3) {
			rxrpc_send_ACK(call, RXRPC_ACK_NOSPACE, serial,
				       rxrpc_propose_ack_input_data);
			goto unlock;
		}
	}

	for (j = 0; j < nr_subpackets; j++) {
		rxrpc_serial_t serial = sp->hdr.serial + j;
		rxrpc_seq_t seq = seq0 + j;
		unsigned int ix = seq & RXRPC_RXTX_BUFF_MASK;
		bool terminal = (j == nr_subpackets - 1);
		bool last = terminal && (sp->rx_flags & RXRPC_SKB_INCL_LAST);
		bool acked = false;
		u8 flags, annotation = j;

		_proto("Rx DATA+%u %%%u { #%x t=%u l=%u }",
		     j, serial, seq, terminal, last);
	_proto("Rx DATA %%%u { #%x l=%u }", serial, seq, last);

	if (last) {
		if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
		    seq != call->rx_top) {
			rxrpc_proto_abort("LSN", call, seq);
				goto unlock;
			goto out;
		}
	} else {
		if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
		    after_eq(seq, call->rx_top)) {
			rxrpc_proto_abort("LSA", call, seq);
				goto unlock;
			goto out;
		}
	}

		flags = 0;
		if (last)
			flags |= RXRPC_LAST_PACKET;
		if (!terminal)
			flags |= RXRPC_JUMBO_PACKET;
		if (test_bit(j, sp->rx_req_ack))
			flags |= RXRPC_REQUEST_ACK;
		trace_rxrpc_rx_data(call->debug_id, seq, serial, flags, annotation);
	trace_rxrpc_rx_data(call->debug_id, seq, serial, sp->hdr.flags);

	if (before_eq(seq, hard_ack)) {
		rxrpc_send_ACK(call, RXRPC_ACK_DUPLICATE, serial,
			       rxrpc_propose_ack_input_data);
			continue;
		goto out;
	}

	if (call->rxtx_buffer[ix]) {
			rxrpc_input_dup_data(call, seq, nr_subpackets > 1,
					     &jumbo_bad);
		rxrpc_send_ACK(call, RXRPC_ACK_DUPLICATE, serial,
			       rxrpc_propose_ack_input_data);
			continue;
		goto out;
	}

	if (after(seq, hard_ack + call->rx_winsize)) {
		rxrpc_send_ACK(call, RXRPC_ACK_EXCEEDS_WINDOW, serial,
			       rxrpc_propose_ack_input_data);
			if (flags & RXRPC_JUMBO_PACKET) {
				if (!jumbo_bad) {
					call->nr_jumbo_bad++;
					jumbo_bad = true;
				}
			}

			goto unlock;
		goto out;
	}

		if (flags & RXRPC_REQUEST_ACK) {
	if (sp->hdr.flags & RXRPC_REQUEST_ACK) {
		rxrpc_send_ACK(call, RXRPC_ACK_REQUESTED, serial,
			       rxrpc_propose_ack_input_data);
		acked = true;
	}

		if (after(seq0, call->ackr_highest_seq))
			call->ackr_highest_seq = seq0;
	if (after(seq, call->ackr_highest_seq))
		call->ackr_highest_seq = seq;

	/* Queue the packet.  We use a couple of memory barriers here as need
	 * to make sure that rx_top is perceived to be set after the buffer
@@ -531,9 +386,7 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
	 * Barriers against rxrpc_recvmsg_data() and rxrpc_rotate_rx_window()
	 * and also rxrpc_fill_out_ack().
	 */
		if (!terminal)
			rxrpc_get_skb(skb, rxrpc_skb_got);
		call->rxtx_annotations[ix] = annotation;
	call->rxtx_annotations[ix] = 1;
	smp_wmb();
	call->rxtx_buffer[ix] = skb;
	if (after(seq, call->rx_top)) {
@@ -547,14 +400,11 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
		}
	}

		if (terminal) {
			/* From this point on, we're not allowed to touch the
			 * packet any longer as its ref now belongs to the Rx
			 * ring.
	/* From this point on, we're not allowed to touch the packet any longer
	 * as its ref now belongs to the Rx ring.
	 */
	skb = NULL;
	sp = NULL;
		}

	if (last) {
		set_bit(RXRPC_CALL_RX_LAST, &call->flags);
@@ -573,23 +423,144 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
		call->rx_expect_next = seq + 1;
	}

		if (!acked) {
			nr_unacked++;
			ack_serial = serial;
		}
	}

unlock:
	if (atomic_add_return(nr_unacked, &call->ackr_nr_unacked) > 2)
		rxrpc_send_ACK(call, RXRPC_ACK_IDLE, ack_serial,
out:
	if (!acked &&
	    atomic_inc_return(&call->ackr_nr_unacked) > 2)
		rxrpc_send_ACK(call, RXRPC_ACK_IDLE, serial,
			       rxrpc_propose_ack_input_data);
	else
		rxrpc_propose_delay_ACK(call, ack_serial,
		rxrpc_propose_delay_ACK(call, serial,
					rxrpc_propose_ack_input_data);

	trace_rxrpc_notify_socket(call->debug_id, serial);
	rxrpc_notify_socket(call);

	rxrpc_free_skb(skb, rxrpc_skb_freed);
	_leave(" [queued]");
}

/*
 * Split a jumbo packet and file the bits separately.
 */
static bool rxrpc_input_split_jumbo(struct rxrpc_call *call, struct sk_buff *skb)
{
	struct rxrpc_jumbo_header jhdr;
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb), *jsp;
	struct sk_buff *jskb;
	unsigned int offset = sizeof(struct rxrpc_wire_header);
	unsigned int len = skb->len - offset;

	while (sp->hdr.flags & RXRPC_JUMBO_PACKET) {
		if (len < RXRPC_JUMBO_SUBPKTLEN)
			goto protocol_error;
		if (sp->hdr.flags & RXRPC_LAST_PACKET)
			goto protocol_error;
		if (skb_copy_bits(skb, offset + RXRPC_JUMBO_DATALEN,
				  &jhdr, sizeof(jhdr)) < 0)
			goto protocol_error;

		jskb = skb_clone(skb, GFP_ATOMIC);
		if (!jskb) {
			kdebug("couldn't clone");
			return false;
		}
		rxrpc_new_skb(jskb, rxrpc_skb_cloned_jumbo);
		jsp = rxrpc_skb(jskb);
		jsp->offset = offset;
		jsp->len = RXRPC_JUMBO_DATALEN;
		rxrpc_input_data_one(call, jskb);

		sp->hdr.flags = jhdr.flags;
		sp->hdr._rsvd = ntohs(jhdr._rsvd);
		sp->hdr.seq++;
		sp->hdr.serial++;
		offset += RXRPC_JUMBO_SUBPKTLEN;
		len -= RXRPC_JUMBO_SUBPKTLEN;
	}

	sp->offset = offset;
	sp->len    = len;
	rxrpc_input_data_one(call, skb);
	return true;

protocol_error:
	return false;
}

/*
 * Process a DATA packet, adding the packet to the Rx ring.  The caller's
 * packet ref must be passed on or discarded.
 */
static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
{
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
	enum rxrpc_call_state state;
	rxrpc_serial_t serial = sp->hdr.serial;
	rxrpc_seq_t seq0 = sp->hdr.seq;

	_enter("{%u,%u},{%u,%u}",
	       call->rx_hard_ack, call->rx_top, skb->len, seq0);

	_proto("Rx DATA %%%u { #%u f=%02x }",
	       sp->hdr.serial, seq0, sp->hdr.flags);

	state = READ_ONCE(call->state);
	if (state >= RXRPC_CALL_COMPLETE) {
		rxrpc_free_skb(skb, rxrpc_skb_freed);
		return;
	}

	/* Unshare the packet so that it can be modified for in-place
	 * decryption.
	 */
	if (sp->hdr.securityIndex != 0) {
		struct sk_buff *nskb = skb_unshare(skb, GFP_ATOMIC);
		if (!nskb) {
			rxrpc_eaten_skb(skb, rxrpc_skb_unshared_nomem);
			return;
		}

		if (nskb != skb) {
			rxrpc_eaten_skb(skb, rxrpc_skb_received);
			skb = nskb;
			rxrpc_new_skb(skb, rxrpc_skb_unshared);
			sp = rxrpc_skb(skb);
		}
	}

	if (state == RXRPC_CALL_SERVER_RECV_REQUEST) {
		unsigned long timo = READ_ONCE(call->next_req_timo);
		unsigned long now, expect_req_by;

		if (timo) {
			now = jiffies;
			expect_req_by = now + timo;
			WRITE_ONCE(call->expect_req_by, expect_req_by);
			rxrpc_reduce_call_timer(call, expect_req_by, now,
						rxrpc_timer_set_for_idle);
		}
	}

	spin_lock(&call->input_lock);

	/* Received data implicitly ACKs all of the request packets we sent
	 * when we're acting as a client.
	 */
	if ((state == RXRPC_CALL_CLIENT_SEND_REQUEST ||
	     state == RXRPC_CALL_CLIENT_AWAIT_REPLY) &&
	    !rxrpc_receiving_reply(call))
		goto out;

	if (!rxrpc_input_split_jumbo(call, skb)) {
		rxrpc_proto_abort("VLD", call, sp->hdr.seq);
		goto out;
	}
	skb = NULL;

out:
	trace_rxrpc_notify_socket(call->debug_id, serial);
	rxrpc_notify_socket(call);

	spin_unlock(&call->input_lock);
	rxrpc_free_skb(skb, rxrpc_skb_freed);
	_leave(" [queued]");
@@ -1288,8 +1259,6 @@ int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb)
		if (sp->hdr.callNumber == 0 ||
		    sp->hdr.seq == 0)
			goto bad_message;
		if (!rxrpc_validate_data(skb))
			goto bad_message;

		/* Unshare the packet so that it can be modified for in-place
		 * decryption.
@@ -1403,7 +1372,7 @@ int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb)
				trace_rxrpc_rx_data(chan->call_debug_id,
						    sp->hdr.seq,
						    sp->hdr.serial,
						    sp->hdr.flags, 0);
						    sp->hdr.flags);
			rxrpc_post_packet_to_conn(conn, skb);
			goto out;
		}
+4 −9
Original line number Diff line number Diff line
@@ -31,10 +31,11 @@ static int none_secure_packet(struct rxrpc_call *call, struct sk_buff *skb,
	return 0;
}

static int none_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
			      unsigned int offset, unsigned int len,
			      rxrpc_seq_t seq, u16 expected_cksum)
static int none_verify_packet(struct rxrpc_call *call, struct sk_buff *skb)
{
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);

	sp->flags |= RXRPC_RX_VERIFIED;
	return 0;
}

@@ -42,11 +43,6 @@ static void none_free_call_crypto(struct rxrpc_call *call)
{
}

static void none_locate_data(struct rxrpc_call *call, struct sk_buff *skb,
			     unsigned int *_offset, unsigned int *_len)
{
}

static int none_respond_to_challenge(struct rxrpc_connection *conn,
				     struct sk_buff *skb,
				     u32 *_abort_code)
@@ -95,7 +91,6 @@ const struct rxrpc_security rxrpc_no_security = {
	.how_much_data			= none_how_much_data,
	.secure_packet			= none_secure_packet,
	.verify_packet			= none_verify_packet,
	.locate_data			= none_locate_data,
	.respond_to_challenge		= none_respond_to_challenge,
	.verify_response		= none_verify_response,
	.clear				= none_clear,
+1 −1
Original line number Diff line number Diff line
@@ -121,7 +121,7 @@ static size_t rxrpc_fill_out_ack(struct rxrpc_connection *conn,

	mtu = conn->params.peer->if_mtu;
	mtu -= conn->params.peer->hdrsize;
	jmax = (call->nr_jumbo_bad > 3) ? 1 : rxrpc_rx_jumbo_max;
	jmax = rxrpc_rx_jumbo_max;
	ackinfo.rxMTU		= htonl(rxrpc_rx_mtu);
	ackinfo.maxMTU		= htonl(mtu);
	ackinfo.rwind		= htonl(call->rx_winsize);
Loading