Commit 75eb6af7 authored by Chuck Lever's avatar Chuck Lever Committed by Trond Myklebust
Browse files

SUNRPC: Add a TCP-with-TLS RPC transport class



Use the new TLS handshake API to enable the SunRPC client code
to request a TLS handshake. This implements support for RFC 9289,
only on TCP sockets.

Upper layers such as NFS use RPC-with-TLS to protect in-transit
traffic.

Signed-off-by: default avatarChuck Lever <chuck.lever@oracle.com>
Signed-off-by: default avatarTrond Myklebust <trond.myklebust@hammerspace.com>
parent dea034b9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -200,6 +200,7 @@ enum xprt_transports {
	XPRT_TRANSPORT_RDMA	= 256,
	XPRT_TRANSPORT_BC_RDMA	= XPRT_TRANSPORT_RDMA | XPRT_TRANSPORT_BC,
	XPRT_TRANSPORT_LOCAL	= 257,
	XPRT_TRANSPORT_TCP_TLS	= 258,
};

struct rpc_sysfs_xprt;
+2 −0
Original line number Diff line number Diff line
@@ -57,9 +57,11 @@ struct sock_xprt {
	struct work_struct	error_worker;
	struct work_struct	recv_worker;
	struct mutex		recv_mutex;
	struct completion	handshake_done;
	struct sockaddr_storage	srcaddr;
	unsigned short		srcport;
	int			xprt_err;
	struct rpc_clnt		*clnt;

	/*
	 * UDP socket buffer size parameters
+44 −0
Original line number Diff line number Diff line
@@ -1525,6 +1525,50 @@ TRACE_EVENT(rpcb_unregister,
	)
);

/**
 ** RPC-over-TLS tracepoints
 **/

DECLARE_EVENT_CLASS(rpc_tls_class,
	TP_PROTO(
		const struct rpc_clnt *clnt,
		const struct rpc_xprt *xprt
	),

	TP_ARGS(clnt, xprt),

	TP_STRUCT__entry(
		__field(unsigned long, requested_policy)
		__field(u32, version)
		__string(servername, xprt->servername)
		__string(progname, clnt->cl_program->name)
	),

	TP_fast_assign(
		__entry->requested_policy = clnt->cl_xprtsec.policy;
		__entry->version = clnt->cl_vers;
		__assign_str(servername, xprt->servername);
		__assign_str(progname, clnt->cl_program->name)
	),

	TP_printk("server=%s %sv%u requested_policy=%s",
		__get_str(servername), __get_str(progname), __entry->version,
		rpc_show_xprtsec_policy(__entry->requested_policy)
	)
);

#define DEFINE_RPC_TLS_EVENT(name) \
	DEFINE_EVENT(rpc_tls_class, rpc_tls_##name, \
			TP_PROTO( \
				const struct rpc_clnt *clnt, \
				const struct rpc_xprt *xprt \
			), \
			TP_ARGS(clnt, xprt))

DEFINE_RPC_TLS_EVENT(unavailable);
DEFINE_RPC_TLS_EVENT(not_started);


/* Record an xdr_buf containing a fully-formed RPC message */
DECLARE_EVENT_CLASS(svc_xdr_msg_class,
	TP_PROTO(
+1 −0
Original line number Diff line number Diff line
@@ -239,6 +239,7 @@ static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj,
	if (!xprt)
		return 0;
	if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP ||
	      xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS ||
	      xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) {
		xprt_put(xprt);
		return -EOPNOTSUPP;
+370 −0
Original line number Diff line number Diff line
@@ -48,6 +48,7 @@
#include <net/udp.h>
#include <net/tcp.h>
#include <net/tls.h>
#include <net/handshake.h>

#include <linux/bvec.h>
#include <linux/highmem.h>
@@ -98,6 +99,7 @@ static struct ctl_table_header *sunrpc_table_header;
static struct xprt_class xs_local_transport;
static struct xprt_class xs_udp_transport;
static struct xprt_class xs_tcp_transport;
static struct xprt_class xs_tcp_tls_transport;
static struct xprt_class xs_bc_tcp_transport;

/*
@@ -189,6 +191,11 @@ static struct ctl_table xs_tunables_table[] = {
 */
#define XS_IDLE_DISC_TO		(5U * 60 * HZ)

/*
 * TLS handshake timeout.
 */
#define XS_TLS_HANDSHAKE_TO	(10U * HZ)

#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
# undef  RPC_DEBUG_DATA
# define RPCDBG_FACILITY	RPCDBG_TRANS
@@ -1243,6 +1250,8 @@ static void xs_reset_transport(struct sock_xprt *transport)
	if (atomic_read(&transport->xprt.swapper))
		sk_clear_memalloc(sk);

	tls_handshake_cancel(sk);

	kernel_sock_shutdown(sock, SHUT_RDWR);

	mutex_lock(&transport->recv_mutex);
@@ -2416,6 +2425,267 @@ static void xs_tcp_setup_socket(struct work_struct *work)
	current_restore_flags(pflags, PF_MEMALLOC);
}

/*
 * Transfer the connected socket to @upper_transport, then mark that
 * xprt CONNECTED.
 */
static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt,
					struct sock_xprt *upper_transport)
{
	struct sock_xprt *lower_transport =
			container_of(lower_xprt, struct sock_xprt, xprt);
	struct rpc_xprt *upper_xprt = &upper_transport->xprt;

	if (!upper_transport->inet) {
		struct socket *sock = lower_transport->sock;
		struct sock *sk = sock->sk;

		/* Avoid temporary address, they are bad for long-lived
		 * connections such as NFS mounts.
		 * RFC4941, section 3.6 suggests that:
		 *    Individual applications, which have specific
		 *    knowledge about the normal duration of connections,
		 *    MAY override this as appropriate.
		 */
		if (xs_addr(upper_xprt)->sa_family == PF_INET6)
			ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC);

		xs_tcp_set_socket_timeouts(upper_xprt, sock);
		tcp_sock_set_nodelay(sk);

		lock_sock(sk);

		/* @sk is already connected, so it now has the RPC callbacks.
		 * Reach into @lower_transport to save the original ones.
		 */
		upper_transport->old_data_ready = lower_transport->old_data_ready;
		upper_transport->old_state_change = lower_transport->old_state_change;
		upper_transport->old_write_space = lower_transport->old_write_space;
		upper_transport->old_error_report = lower_transport->old_error_report;
		sk->sk_user_data = upper_xprt;

		/* socket options */
		sock_reset_flag(sk, SOCK_LINGER);

		xprt_clear_connected(upper_xprt);

		upper_transport->sock = sock;
		upper_transport->inet = sk;
		upper_transport->file = lower_transport->file;

		release_sock(sk);

		/* Reset lower_transport before shutting down its clnt */
		mutex_lock(&lower_transport->recv_mutex);
		lower_transport->inet = NULL;
		lower_transport->sock = NULL;
		lower_transport->file = NULL;

		xprt_clear_connected(lower_xprt);
		xs_sock_reset_connection_flags(lower_xprt);
		xs_stream_reset_connect(lower_transport);
		mutex_unlock(&lower_transport->recv_mutex);
	}

	if (!xprt_bound(upper_xprt))
		return -ENOTCONN;

	xs_set_memalloc(upper_xprt);

	if (!xprt_test_and_set_connected(upper_xprt)) {
		upper_xprt->connect_cookie++;
		clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
		xprt_clear_connecting(upper_xprt);

		upper_xprt->stat.connect_count++;
		upper_xprt->stat.connect_time += (long)jiffies -
					   upper_xprt->stat.connect_start;
		xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
	}
	return 0;
}

/**
 * xs_tls_handshake_done - TLS handshake completion handler
 * @data: address of xprt to wake
 * @status: status of handshake
 * @peerid: serial number of key containing the remote's identity
 *
 */
static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
{
	struct rpc_xprt *lower_xprt = data;
	struct sock_xprt *lower_transport =
				container_of(lower_xprt, struct sock_xprt, xprt);

	lower_transport->xprt_err = status ? -EACCES : 0;
	complete(&lower_transport->handshake_done);
	xprt_put(lower_xprt);
}

static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
{
	struct sock_xprt *lower_transport =
				container_of(lower_xprt, struct sock_xprt, xprt);
	struct tls_handshake_args args = {
		.ta_sock	= lower_transport->sock,
		.ta_done	= xs_tls_handshake_done,
		.ta_data	= xprt_get(lower_xprt),
		.ta_peername	= lower_xprt->servername,
	};
	struct sock *sk = lower_transport->inet;
	int rc;

	init_completion(&lower_transport->handshake_done);
	set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
	lower_transport->xprt_err = -ETIMEDOUT;
	switch (xprtsec->policy) {
	case RPC_XPRTSEC_TLS_ANON:
		rc = tls_client_hello_anon(&args, GFP_KERNEL);
		if (rc)
			goto out_put_xprt;
		break;
	case RPC_XPRTSEC_TLS_X509:
		args.ta_my_cert = xprtsec->cert_serial;
		args.ta_my_privkey = xprtsec->privkey_serial;
		rc = tls_client_hello_x509(&args, GFP_KERNEL);
		if (rc)
			goto out_put_xprt;
		break;
	default:
		rc = -EACCES;
		goto out_put_xprt;
	}

	rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
						       XS_TLS_HANDSHAKE_TO);
	if (rc <= 0) {
		if (!tls_handshake_cancel(sk)) {
			if (rc == 0)
				rc = -ETIMEDOUT;
			goto out_put_xprt;
		}
	}

	rc = lower_transport->xprt_err;

out:
	xs_stream_reset_connect(lower_transport);
	clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
	return rc;

out_put_xprt:
	xprt_put(lower_xprt);
	goto out;
}

/**
 * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket
 * @work: queued work item
 *
 * Invoked by a work queue tasklet.
 *
 * For RPC-with-TLS, there is a two-stage connection process.
 *
 * The "upper-layer xprt" is visible to the RPC consumer. Once it has
 * been marked connected, the consumer knows that a TCP connection and
 * a TLS session have been established.
 *
 * A "lower-layer xprt", created in this function, handles the mechanics
 * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
 * then driving the TLS handshake. Once all that is complete, the upper
 * layer xprt is marked connected.
 */
static void xs_tcp_tls_setup_socket(struct work_struct *work)
{
	struct sock_xprt *upper_transport =
		container_of(work, struct sock_xprt, connect_worker.work);
	struct rpc_clnt *upper_clnt = upper_transport->clnt;
	struct rpc_xprt *upper_xprt = &upper_transport->xprt;
	struct rpc_create_args args = {
		.net		= upper_xprt->xprt_net,
		.protocol	= upper_xprt->prot,
		.address	= (struct sockaddr *)&upper_xprt->addr,
		.addrsize	= upper_xprt->addrlen,
		.timeout	= upper_clnt->cl_timeout,
		.servername	= upper_xprt->servername,
		.program	= upper_clnt->cl_program,
		.prognumber	= upper_clnt->cl_prog,
		.version	= upper_clnt->cl_vers,
		.authflavor	= RPC_AUTH_TLS,
		.cred		= upper_clnt->cl_cred,
		.xprtsec	= {
			.policy		= RPC_XPRTSEC_NONE,
		},
	};
	unsigned int pflags = current->flags;
	struct rpc_clnt *lower_clnt;
	struct rpc_xprt *lower_xprt;
	int status;

	if (atomic_read(&upper_xprt->swapper))
		current->flags |= PF_MEMALLOC;

	xs_stream_start_connect(upper_transport);

	/* This implicitly sends an RPC_AUTH_TLS probe */
	lower_clnt = rpc_create(&args);
	if (IS_ERR(lower_clnt)) {
		trace_rpc_tls_unavailable(upper_clnt, upper_xprt);
		clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
		xprt_clear_connecting(upper_xprt);
		xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
		xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
		goto out_unlock;
	}

	/* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
	 * the lower xprt.
	 */
	rcu_read_lock();
	lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
	rcu_read_unlock();
	status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
	if (status) {
		trace_rpc_tls_not_started(upper_clnt, upper_xprt);
		goto out_close;
	}

	status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
	if (status)
		goto out_close;

	trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
	if (!xprt_test_and_set_connected(upper_xprt)) {
		upper_xprt->connect_cookie++;
		clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
		xprt_clear_connecting(upper_xprt);

		upper_xprt->stat.connect_count++;
		upper_xprt->stat.connect_time += (long)jiffies -
					   upper_xprt->stat.connect_start;
		xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
	}
	rpc_shutdown_client(lower_clnt);

out_unlock:
	current_restore_flags(pflags, PF_MEMALLOC);
	upper_transport->clnt = NULL;
	xprt_unlock_connect(upper_xprt, upper_transport);
	return;

out_close:
	rpc_shutdown_client(lower_clnt);

	/* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
	 * Wake them first here to ensure they get our tk_status code.
	 */
	xprt_wake_pending_tasks(upper_xprt, status);
	xs_tcp_force_close(upper_xprt);
	xprt_clear_connecting(upper_xprt);
	goto out_unlock;
}

/**
 * xs_connect - connect a socket to a remote endpoint
 * @xprt: pointer to transport structure
@@ -2447,6 +2717,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
	} else
		dprintk("RPC:       xs_connect scheduled xprt %p\n", xprt);

	transport->clnt = task->tk_client;
	queue_delayed_work(xprtiod_workqueue,
			&transport->connect_worker,
			delay);
@@ -3100,6 +3371,94 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
	return ret;
}

/**
 * xs_setup_tcp_tls - Set up transport to use a TCP with TLS
 * @args: rpc transport creation arguments
 *
 */
static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args)
{
	struct sockaddr *addr = args->dstaddr;
	struct rpc_xprt *xprt;
	struct sock_xprt *transport;
	struct rpc_xprt *ret;
	unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;

	if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
		max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;

	xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
			     max_slot_table_size);
	if (IS_ERR(xprt))
		return xprt;
	transport = container_of(xprt, struct sock_xprt, xprt);

	xprt->prot = IPPROTO_TCP;
	xprt->xprt_class = &xs_tcp_transport;
	xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;

	xprt->bind_timeout = XS_BIND_TO;
	xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
	xprt->idle_timeout = XS_IDLE_DISC_TO;

	xprt->ops = &xs_tcp_ops;
	xprt->timeout = &xs_tcp_default_timeout;

	xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
	xprt->connect_timeout = xprt->timeout->to_initval *
		(xprt->timeout->to_retries + 1);

	INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
	INIT_WORK(&transport->error_worker, xs_error_handle);

	switch (args->xprtsec.policy) {
	case RPC_XPRTSEC_TLS_ANON:
	case RPC_XPRTSEC_TLS_X509:
		xprt->xprtsec = args->xprtsec;
		INIT_DELAYED_WORK(&transport->connect_worker,
				  xs_tcp_tls_setup_socket);
		break;
	default:
		ret = ERR_PTR(-EACCES);
		goto out_err;
	}

	switch (addr->sa_family) {
	case AF_INET:
		if (((struct sockaddr_in *)addr)->sin_port != htons(0))
			xprt_set_bound(xprt);

		xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
		break;
	case AF_INET6:
		if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
			xprt_set_bound(xprt);

		xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
		break;
	default:
		ret = ERR_PTR(-EAFNOSUPPORT);
		goto out_err;
	}

	if (xprt_bound(xprt))
		dprintk("RPC:       set up xprt to %s (port %s) via %s\n",
			xprt->address_strings[RPC_DISPLAY_ADDR],
			xprt->address_strings[RPC_DISPLAY_PORT],
			xprt->address_strings[RPC_DISPLAY_PROTO]);
	else
		dprintk("RPC:       set up xprt to %s (autobind) via %s\n",
			xprt->address_strings[RPC_DISPLAY_ADDR],
			xprt->address_strings[RPC_DISPLAY_PROTO]);

	if (try_module_get(THIS_MODULE))
		return xprt;
	ret = ERR_PTR(-EINVAL);
out_err:
	xs_xprt_free(xprt);
	return ret;
}

/**
 * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket
 * @args: rpc transport creation arguments
@@ -3209,6 +3568,15 @@ static struct xprt_class xs_tcp_transport = {
	.netid		= { "tcp", "tcp6", "" },
};

static struct xprt_class	xs_tcp_tls_transport = {
	.list		= LIST_HEAD_INIT(xs_tcp_tls_transport.list),
	.name		= "tcp-with-tls",
	.owner		= THIS_MODULE,
	.ident		= XPRT_TRANSPORT_TCP_TLS,
	.setup		= xs_setup_tcp_tls,
	.netid		= { "tcp", "tcp6", "" },
};

static struct xprt_class	xs_bc_tcp_transport = {
	.list		= LIST_HEAD_INIT(xs_bc_tcp_transport.list),
	.name		= "tcp NFSv4.1 backchannel",
@@ -3230,6 +3598,7 @@ int init_socket_xprt(void)
	xprt_register_transport(&xs_local_transport);
	xprt_register_transport(&xs_udp_transport);
	xprt_register_transport(&xs_tcp_transport);
	xprt_register_transport(&xs_tcp_tls_transport);
	xprt_register_transport(&xs_bc_tcp_transport);

	return 0;
@@ -3249,6 +3618,7 @@ void cleanup_socket_xprt(void)
	xprt_unregister_transport(&xs_local_transport);
	xprt_unregister_transport(&xs_udp_transport);
	xprt_unregister_transport(&xs_tcp_transport);
	xprt_unregister_transport(&xs_tcp_tls_transport);
	xprt_unregister_transport(&xs_bc_tcp_transport);
}