Commit 5e052dda authored by Chuck Lever's avatar Chuck Lever
Browse files

SUNRPC: Recognize control messages in server-side TCP socket code



To support kTLS, the server-side TCP socket receive path needs to
watch for CMSGs.

Acked-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarChuck Lever <chuck.lever@oracle.com>
parent 6a0cdf56
Loading
Loading
Loading
Loading
+2 −0
Original line number Original line Diff line number Diff line
@@ -69,6 +69,8 @@ extern const struct tls_cipher_size_desc tls_cipher_size_desc[];


#define TLS_CRYPTO_INFO_READY(info)	((info)->cipher_type)
#define TLS_CRYPTO_INFO_READY(info)	((info)->cipher_type)


#define TLS_RECORD_TYPE_ALERT		0x15
#define TLS_RECORD_TYPE_HANDSHAKE	0x16
#define TLS_RECORD_TYPE_DATA		0x17
#define TLS_RECORD_TYPE_DATA		0x17


#define TLS_AAD_SPACE_SIZE		13
#define TLS_AAD_SPACE_SIZE		13
+46 −2
Original line number Original line Diff line number Diff line
@@ -43,6 +43,7 @@
#include <net/udp.h>
#include <net/udp.h>
#include <net/tcp.h>
#include <net/tcp.h>
#include <net/tcp_states.h>
#include <net/tcp_states.h>
#include <net/tls.h>
#include <linux/uaccess.h>
#include <linux/uaccess.h>
#include <linux/highmem.h>
#include <linux/highmem.h>
#include <asm/ioctls.h>
#include <asm/ioctls.h>
@@ -216,6 +217,49 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining)
	return len;
	return len;
}
}


static int
svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg,
			  struct cmsghdr *cmsg, int ret)
{
	if (cmsg->cmsg_level == SOL_TLS &&
	    cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
		u8 content_type = *((u8 *)CMSG_DATA(cmsg));

		switch (content_type) {
		case TLS_RECORD_TYPE_DATA:
			/* TLS sets EOR at the end of each application data
			 * record, even though there might be more frames
			 * waiting to be decrypted.
			 */
			msg->msg_flags &= ~MSG_EOR;
			break;
		case TLS_RECORD_TYPE_ALERT:
			ret = -ENOTCONN;
			break;
		default:
			ret = -EAGAIN;
		}
	}
	return ret;
}

static int
svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg)
{
	union {
		struct cmsghdr	cmsg;
		u8		buf[CMSG_SPACE(sizeof(u8))];
	} u;
	int ret;

	msg->msg_control = &u;
	msg->msg_controllen = sizeof(u);
	ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT);
	if (unlikely(msg->msg_controllen != sizeof(u)))
		ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret);
	return ret;
}

#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek)
static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek)
{
{
@@ -263,7 +307,7 @@ static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen,
		iov_iter_advance(&msg.msg_iter, seek);
		iov_iter_advance(&msg.msg_iter, seek);
		buflen -= seek;
		buflen -= seek;
	}
	}
	len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT);
	len = svc_tcp_sock_recv_cmsg(svsk, &msg);
	if (len > 0)
	if (len > 0)
		svc_flush_bvec(bvec, len, seek);
		svc_flush_bvec(bvec, len, seek);


@@ -877,7 +921,7 @@ static ssize_t svc_tcp_read_marker(struct svc_sock *svsk,
		iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen;
		iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen;
		iov.iov_len  = want;
		iov.iov_len  = want;
		iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want);
		iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want);
		len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT);
		len = svc_tcp_sock_recv_cmsg(svsk, &msg);
		if (len < 0)
		if (len < 0)
			return len;
			return len;
		svsk->sk_tcplen += len;
		svsk->sk_tcplen += len;