Commit 2d1faf7a authored by David Howells's avatar David Howells
Browse files

rxrpc: Simplify skbuff accounting in receive path



A received skbuff needs a ref when it gets put on a call data queue or conn
packet queue, and rxrpc_input_packet() and co. jump through a lot of hoops
to avoid double-dropping the skbuff ref so that we can avoid getting a ref
when we queue the packet.

Change this so that the skbuff ref is unconditionally dropped by the caller
of rxrpc_input_packet().  An additional ref is then taken on the packet if
it is pushed onto a queue.

Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: linux-afs@lists.infradead.org
parent 29fb4ec3
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -28,6 +28,8 @@
	EM(rxrpc_skb_eaten_by_unshare_nomem,	"ETN unshar-nm") \
	EM(rxrpc_skb_get_ack,			"GET ack      ") \
	EM(rxrpc_skb_get_conn_work,		"GET conn-work") \
	EM(rxrpc_skb_get_local_work,		"GET locl-work") \
	EM(rxrpc_skb_get_reject_work,		"GET rej-work ") \
	EM(rxrpc_skb_get_to_recvmsg,		"GET to-recv  ") \
	EM(rxrpc_skb_get_to_recvmsg_oos,	"GET to-recv-o") \
	EM(rxrpc_skb_new_encap_rcv,		"NEW encap-rcv") \
@@ -39,7 +41,6 @@
	EM(rxrpc_skb_put_error_report,		"PUT error-rep") \
	EM(rxrpc_skb_put_input,			"PUT input    ") \
	EM(rxrpc_skb_put_jumbo_subpacket,	"PUT jumbo-sub") \
	EM(rxrpc_skb_put_lose,			"PUT lose     ") \
	EM(rxrpc_skb_put_purge,			"PUT purge    ") \
	EM(rxrpc_skb_put_rotate,		"PUT rotate   ") \
	EM(rxrpc_skb_put_unknown,		"PUT unknown  ") \
+22 −23
Original line number Diff line number Diff line
@@ -338,7 +338,8 @@ static void rxrpc_input_queue_data(struct rxrpc_call *call, struct sk_buff *skb,
/*
 * Process a DATA packet.
 */
static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb,
				 bool *_notify)
{
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
	struct sk_buff *oos;
@@ -361,7 +362,7 @@ static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
		if (test_and_set_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
		    seq + 1 != wtop) {
			rxrpc_proto_abort("LSN", call, seq);
			goto err_free;
			return;
		}
	} else {
		if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
@@ -369,7 +370,7 @@ static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
			pr_warn("Packet beyond last: c=%x q=%x window=%x-%x wlimit=%x\n",
				call->debug_id, seq, window, wtop, wlimit);
			rxrpc_proto_abort("LSA", call, seq);
			goto err_free;
			return;
		}
	}

@@ -402,9 +403,11 @@ static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
		if (after(window, wtop))
			wtop = window;

		rxrpc_get_skb(skb, rxrpc_skb_get_to_recvmsg);

		spin_lock(&call->recvmsg_queue.lock);
		rxrpc_input_queue_data(call, skb, window, wtop, rxrpc_receive_queue);
		skb = NULL;
		*_notify = true;

		while ((oos = skb_peek(&call->rx_oos_queue))) {
			struct rxrpc_skb_priv *osp = rxrpc_skb(oos);
@@ -456,16 +459,17 @@ static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
			struct rxrpc_skb_priv *osp = rxrpc_skb(oos);

			if (after(osp->hdr.seq, seq)) {
				rxrpc_get_skb(skb, rxrpc_skb_get_to_recvmsg_oos);
				__skb_queue_before(&call->rx_oos_queue, oos, skb);
				goto oos_queued;
			}
		}

		rxrpc_get_skb(skb, rxrpc_skb_get_to_recvmsg_oos);
		__skb_queue_tail(&call->rx_oos_queue, skb);
	oos_queued:
		trace_rxrpc_receive(call, last ? rxrpc_receive_oos_last : rxrpc_receive_oos,
				    sp->hdr.serial, sp->hdr.seq);
		skb = NULL;
	}

send_ack:
@@ -483,9 +487,6 @@ static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb)
	else
		rxrpc_propose_delay_ACK(call, serial,
					rxrpc_propose_ack_input_data);

err_free:
	rxrpc_free_skb(skb, rxrpc_skb_put_input);
}

/*
@@ -498,6 +499,7 @@ static bool rxrpc_input_split_jumbo(struct rxrpc_call *call, struct sk_buff *skb
	struct sk_buff *jskb;
	unsigned int offset = sizeof(struct rxrpc_wire_header);
	unsigned int len = skb->len - offset;
	bool notify = false;

	while (sp->hdr.flags & RXRPC_JUMBO_PACKET) {
		if (len < RXRPC_JUMBO_SUBPKTLEN)
@@ -517,7 +519,8 @@ static bool rxrpc_input_split_jumbo(struct rxrpc_call *call, struct sk_buff *skb
		jsp = rxrpc_skb(jskb);
		jsp->offset = offset;
		jsp->len = RXRPC_JUMBO_DATALEN;
		rxrpc_input_data_one(call, jskb);
		rxrpc_input_data_one(call, jskb, &notify);
		rxrpc_free_skb(jskb, rxrpc_skb_put_jumbo_subpacket);

		sp->hdr.flags = jhdr.flags;
		sp->hdr._rsvd = ntohs(jhdr._rsvd);
@@ -529,7 +532,11 @@ static bool rxrpc_input_split_jumbo(struct rxrpc_call *call, struct sk_buff *skb

	sp->offset = offset;
	sp->len    = len;
	rxrpc_input_data_one(call, skb);
	rxrpc_input_data_one(call, skb, &notify);
	if (notify) {
		trace_rxrpc_notify_socket(call->debug_id, sp->hdr.serial);
		rxrpc_notify_socket(call);
	}
	return true;

protocol_error:
@@ -552,10 +559,8 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
	       skb->len, seq0);

	state = READ_ONCE(call->state);
	if (state >= RXRPC_CALL_COMPLETE) {
		rxrpc_free_skb(skb, rxrpc_skb_put_input);
	if (state >= RXRPC_CALL_COMPLETE)
		return;
	}

	/* Unshare the packet so that it can be modified for in-place
	 * decryption.
@@ -605,7 +610,6 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
out:
	trace_rxrpc_notify_socket(call->debug_id, serial);
	rxrpc_notify_socket(call);
	rxrpc_free_skb(skb, rxrpc_skb_put_input);
	_leave(" [queued]");
}

@@ -797,7 +801,7 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb)
	struct rxrpc_ackpacket ack;
	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
	struct rxrpc_ackinfo info;
	struct sk_buff *skb_old = NULL, *skb_put = skb;
	struct sk_buff *skb_old = NULL;
	rxrpc_serial_t ack_serial, acked_serial;
	rxrpc_seq_t first_soft_ack, hard_ack, prev_pkt;
	int nr_acks, offset, ioffset;
@@ -963,6 +967,7 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb)
			goto out;
		}

		rxrpc_get_skb(skb, rxrpc_skb_get_ack);
		spin_lock(&call->acks_ack_lock);
		skb_old = call->acks_soft_tbl;
		call->acks_soft_tbl = skb;
@@ -970,7 +975,6 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb)

		rxrpc_input_soft_acks(call, skb->data + offset, first_soft_ack,
				      nr_acks, &summary);
		skb_put = NULL;
	} else if (call->acks_soft_tbl) {
		spin_lock(&call->acks_ack_lock);
		skb_old = call->acks_soft_tbl;
@@ -986,7 +990,6 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb)

	rxrpc_congestion_management(call, skb, &summary, acked_serial);
out:
	rxrpc_free_skb(skb_put, rxrpc_skb_put_input);
	rxrpc_free_skb(skb_old, rxrpc_skb_put_ack);
}

@@ -1037,11 +1040,11 @@ void rxrpc_input_call_event(struct rxrpc_call *call, struct sk_buff *skb)
	switch (sp->hdr.type) {
	case RXRPC_PACKET_TYPE_DATA:
		rxrpc_input_data(call, skb);
		goto no_free;
		break;

	case RXRPC_PACKET_TYPE_ACK:
		rxrpc_input_ack(call, skb);
		goto no_free;
		break;

	case RXRPC_PACKET_TYPE_BUSY:
		/* Just ignore BUSY packets from the server; the retry and
@@ -1061,10 +1064,6 @@ void rxrpc_input_call_event(struct rxrpc_call *call, struct sk_buff *skb)
	default:
		break;
	}

	rxrpc_free_skb(skb, rxrpc_skb_put_input);
no_free:
	_leave("");
}

/*
+32 −38
Original line number Diff line number Diff line
@@ -72,6 +72,7 @@ static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn,
{
	_enter("%p,%p", conn, skb);

	rxrpc_get_skb(skb, rxrpc_skb_get_conn_work);
	skb_queue_tail(&conn->rx_queue, skb);
	rxrpc_queue_conn(conn, rxrpc_conn_queue_rx_work);
}
@@ -86,10 +87,9 @@ static void rxrpc_post_packet_to_local(struct rxrpc_local *local,
	_enter("%p,%p", local, skb);

	if (rxrpc_get_local_maybe(local, rxrpc_local_get_queue)) {
		rxrpc_get_skb(skb, rxrpc_skb_get_local_work);
		skb_queue_tail(&local->event_queue, skb);
		rxrpc_queue_local(local);
	} else {
		rxrpc_free_skb(skb, rxrpc_skb_put_input);
	}
}

@@ -99,10 +99,9 @@ static void rxrpc_post_packet_to_local(struct rxrpc_local *local,
static void rxrpc_reject_packet(struct rxrpc_local *local, struct sk_buff *skb)
{
	if (rxrpc_get_local_maybe(local, rxrpc_local_get_queue)) {
		rxrpc_get_skb(skb, rxrpc_skb_get_reject_work);
		skb_queue_tail(&local->reject_queue, skb);
		rxrpc_queue_local(local);
	} else {
		rxrpc_free_skb(skb, rxrpc_skb_put_input);
	}
}

@@ -153,7 +152,7 @@ static bool rxrpc_extract_abort(struct sk_buff *skb)
/*
 * Process packets received on the local endpoint
 */
static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
{
	struct rxrpc_connection *conn;
	struct rxrpc_channel *chan;
@@ -161,6 +160,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
	struct rxrpc_skb_priv *sp;
	struct rxrpc_peer *peer = NULL;
	struct rxrpc_sock *rx = NULL;
	struct sk_buff *skb = *_skb;
	unsigned int channel;

	if (skb->tstamp == 0)
@@ -181,7 +181,6 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
		static int lose;
		if ((lose++ & 7) == 7) {
			trace_rxrpc_rx_lose(sp);
			rxrpc_free_skb(skb, rxrpc_skb_put_lose);
			return 0;
		}
	}
@@ -193,13 +192,13 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
	switch (sp->hdr.type) {
	case RXRPC_PACKET_TYPE_VERSION:
		if (rxrpc_to_client(sp))
			goto discard;
			return 0;
		rxrpc_post_packet_to_local(local, skb);
		goto out;
		return 0;

	case RXRPC_PACKET_TYPE_BUSY:
		if (rxrpc_to_server(sp))
			goto discard;
			return 0;
		fallthrough;
	case RXRPC_PACKET_TYPE_ACK:
	case RXRPC_PACKET_TYPE_ACKALL:
@@ -208,7 +207,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
		break;
	case RXRPC_PACKET_TYPE_ABORT:
		if (!rxrpc_extract_abort(skb))
			return true; /* Just discard if malformed */
			return 0; /* Just discard if malformed */
		break;

	case RXRPC_PACKET_TYPE_DATA:
@@ -220,15 +219,16 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
		 * decryption.
		 */
		if (sp->hdr.securityIndex != 0) {
			struct sk_buff *nskb = skb_unshare(skb, GFP_ATOMIC);
			if (!nskb) {
				rxrpc_eaten_skb(skb, rxrpc_skb_eaten_by_unshare_nomem);
				goto out;
			skb = skb_unshare(skb, GFP_ATOMIC);
			if (!skb) {
				rxrpc_eaten_skb(*_skb, rxrpc_skb_eaten_by_unshare_nomem);
				*_skb = NULL;
				return 0;
			}

			if (nskb != skb) {
				rxrpc_eaten_skb(skb, rxrpc_skb_eaten_by_unshare);
				skb = nskb;
			if (skb != *_skb) {
				rxrpc_eaten_skb(*_skb, rxrpc_skb_eaten_by_unshare);
				*_skb = skb;
				rxrpc_new_skb(skb, rxrpc_skb_new_unshared);
				sp = rxrpc_skb(skb);
			}
@@ -237,18 +237,18 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)

	case RXRPC_PACKET_TYPE_CHALLENGE:
		if (rxrpc_to_server(sp))
			goto discard;
			return 0;
		break;
	case RXRPC_PACKET_TYPE_RESPONSE:
		if (rxrpc_to_client(sp))
			goto discard;
			return 0;
		break;

		/* Packet types 9-11 should just be ignored. */
	case RXRPC_PACKET_TYPE_PARAMS:
	case RXRPC_PACKET_TYPE_10:
	case RXRPC_PACKET_TYPE_11:
		goto discard;
		return 0;

	default:
		goto bad_message;
@@ -268,7 +268,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
			if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA &&
			    sp->hdr.seq == 1)
				goto unsupported_service;
			goto discard;
			return 0;
		}
	}

@@ -294,7 +294,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
			/* Connection-level packet */
			_debug("CONN %p {%d}", conn, conn->debug_id);
			rxrpc_post_packet_to_conn(conn, skb);
			goto out;
			return 0;
		}

		if ((int)sp->hdr.serial - (int)conn->hi_serial > 0)
@@ -306,19 +306,19 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)

		/* Ignore really old calls */
		if (sp->hdr.callNumber < chan->last_call)
			goto discard;
			return 0;

		if (sp->hdr.callNumber == chan->last_call) {
			if (chan->call ||
			    sp->hdr.type == RXRPC_PACKET_TYPE_ABORT)
				goto discard;
				return 0;

			/* For the previous service call, if completed
			 * successfully, we discard all further packets.
			 */
			if (rxrpc_conn_is_service(conn) &&
			    chan->last_type == RXRPC_PACKET_TYPE_ACK)
				goto discard;
				return 0;

			/* But otherwise we need to retransmit the final packet
			 * from data cached in the connection record.
@@ -329,7 +329,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
						    sp->hdr.serial,
						    sp->hdr.flags);
			rxrpc_post_packet_to_conn(conn, skb);
			goto out;
			return 0;
		}

		call = rcu_dereference(chan->call);
@@ -357,21 +357,14 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
		    sp->hdr.type != RXRPC_PACKET_TYPE_DATA)
			goto bad_message;
		if (sp->hdr.seq != 1)
			goto discard;
			return 0;
		call = rxrpc_new_incoming_call(local, rx, skb);
		if (!call)
			goto reject_packet;
	}

	/* Process a call packet; this either discards or passes on the ref
	 * elsewhere.
	 */
	/* Process a call packet. */
	rxrpc_input_call_event(call, skb);
	goto out;

discard:
	rxrpc_free_skb(skb, rxrpc_skb_put_input);
out:
	trace_rxrpc_rx_done(0, 0);
	return 0;

@@ -400,9 +393,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff *skb)
post_abort:
	skb->mark = RXRPC_SKB_MARK_REJECT_ABORT;
reject_packet:
	trace_rxrpc_rx_done(skb->mark, skb->priority);
	rxrpc_reject_packet(local, skb);
	_leave(" [badmsg]");
	return 0;
}

@@ -441,9 +432,12 @@ int rxrpc_io_thread(void *data)
		if ((skb = __skb_dequeue(&rx_queue))) {
			switch (skb->mark) {
			case RXRPC_SKB_MARK_PACKET:
				skb->priority = 0;
				rcu_read_lock();
				rxrpc_input_packet(local, skb);
				rxrpc_input_packet(local, &skb);
				rcu_read_unlock();
				trace_rxrpc_rx_done(skb->mark, skb->priority);
				rxrpc_free_skb(skb, rxrpc_skb_put_input);
				break;
			case RXRPC_SKB_MARK_ERROR:
				rxrpc_input_error(local, skb);