Commit 49573ff7 authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'tls-splice_read-fixes'

Jakub Kicinski says:

====================
tls: splice_read fixes

As I work my way to unlocked and zero-copy TLS Rx the obvious bugs
in the splice_read implementation get harder and harder to ignore.
This is to say the fixes here are discovered by code inspection,
I'm not aware of anyone actually using splice_read.
====================

Link: https://lore.kernel.org/r/20211124232557.2039757-1-kuba@kernel.org


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 9dbe33cf f884a342
Loading
Loading
Loading
Loading
+40 −7
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

@@ -71,6 +71,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -669,8 +671,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		sk->sk_socket->ops = &tls_sw_proto_ops;
	}
	goto out;

@@ -728,6 +728,39 @@ struct tls_context *tls_ctx_create(struct sock *sk)
	return ctx;
}

static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].sendpage_locked	= tls_sw_sendpage_locked;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_HW  ][TLS_BASE].sendpage_locked	= NULL;

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
	ops[TLS_HW  ][TLS_SW  ].sendpage_locked	= NULL;

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
	ops[TLS_HW  ][TLS_HW  ].sendpage_locked	= NULL;
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
@@ -739,6 +772,8 @@ static void tls_build_proto(struct sock *sk)
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
@@ -749,6 +784,8 @@ static void tls_build_proto(struct sock *sk)
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
@@ -959,10 +996,6 @@ static int __init tls_register(void)
	if (err)
		return err;

	tls_sw_proto_ops = inet_stream_ops;
	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
	tls_sw_proto_ops.sendpage_locked   = tls_sw_sendpage_locked;

	tls_device_init();
	tcp_register_ulp(&tcp_tls_ulp_ops);

+27 −13
Original line number Diff line number Diff line
@@ -2005,6 +2005,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	struct sock *sk = sock->sk;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool from_queue;
	int err = 0;
	long timeo;
	int chunk;
@@ -2014,12 +2015,21 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,

	timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);

	skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
	from_queue = !skb_queue_empty(&ctx->rx_list);
	if (from_queue) {
		skb = __skb_dequeue(&ctx->rx_list);
	} else {
		skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
				    &err);
		if (!skb)
			goto splice_read_end;

	if (!ctx->decrypted) {
		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto splice_read_end;
		}
	}

	/* splice does not support reading control messages */
	if (ctx->control != TLS_RECORD_TYPE_DATA) {
@@ -2027,12 +2037,6 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
		goto splice_read_end;
	}

		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto splice_read_end;
		}
		ctx->decrypted = 1;
	}
	rxm = strp_msg(skb);

	chunk = min_t(unsigned int, rxm->full_len, len);
@@ -2040,7 +2044,17 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	if (copied < 0)
		goto splice_read_end;

	tls_sw_advance_skb(sk, skb, copied);
	if (!from_queue) {
		ctx->recv_pkt = NULL;
		__strp_unpause(&ctx->strp);
	}
	if (chunk < rxm->full_len) {
		__skb_queue_head(&ctx->rx_list, skb);
		rxm->offset += len;
		rxm->full_len -= len;
	} else {
		consume_skb(skb);
	}

splice_read_end:
	release_sock(sk);
+389 −132
Original line number Diff line number Diff line
@@ -78,26 +78,21 @@ static void memrnd(void *s, size_t n)
		*byte++ = rand();
}

FIXTURE(tls_basic)
{
	int fd, cfd;
	bool notls;
};

FIXTURE_SETUP(tls_basic)
static void ulp_sock_pair(struct __test_metadata *_metadata,
			  int *fd, int *cfd, bool *notls)
{
	struct sockaddr_in addr;
	socklen_t len;
	int sfd, ret;

	self->notls = false;
	*notls = false;
	len = sizeof(addr);

	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = 0;

	self->fd = socket(AF_INET, SOCK_STREAM, 0);
	*fd = socket(AF_INET, SOCK_STREAM, 0);
	sfd = socket(AF_INET, SOCK_STREAM, 0);

	ret = bind(sfd, &addr, sizeof(addr));
@@ -108,26 +103,96 @@ FIXTURE_SETUP(tls_basic)
	ret = getsockname(sfd, &addr, &len);
	ASSERT_EQ(ret, 0);

	ret = connect(self->fd, &addr, sizeof(addr));
	ret = connect(*fd, &addr, sizeof(addr));
	ASSERT_EQ(ret, 0);

	self->cfd = accept(sfd, &addr, &len);
	ASSERT_GE(self->cfd, 0);
	*cfd = accept(sfd, &addr, &len);
	ASSERT_GE(*cfd, 0);

	close(sfd);

	ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	if (ret != 0) {
		ASSERT_EQ(errno, ENOENT);
		self->notls = true;
		*notls = true;
		printf("Failure setting TCP_ULP, testing without tls\n");
		return;
	}

	ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	ASSERT_EQ(ret, 0);
}

/* Produce a basic cmsg */
static int tls_send_cmsg(int fd, unsigned char record_type,
			 void *data, size_t len, int flags)
{
	char cbuf[CMSG_SPACE(sizeof(char))];
	int cmsg_len = sizeof(char);
	struct cmsghdr *cmsg;
	struct msghdr msg;
	struct iovec vec;

	vec.iov_base = data;
	vec.iov_len = len;
	memset(&msg, 0, sizeof(struct msghdr));
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_TLS;
	/* test sending non-record types. */
	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
	*CMSG_DATA(cmsg) = record_type;
	msg.msg_controllen = cmsg->cmsg_len;

	return sendmsg(fd, &msg, flags);
}

static int tls_recv_cmsg(struct __test_metadata *_metadata,
			 int fd, unsigned char record_type,
			 void *data, size_t len, int flags)
{
	char cbuf[CMSG_SPACE(sizeof(char))];
	struct cmsghdr *cmsg;
	unsigned char ctype;
	struct msghdr msg;
	struct iovec vec;
	int n;

	vec.iov_base = data;
	vec.iov_len = len;
	memset(&msg, 0, sizeof(struct msghdr));
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);

	n = recvmsg(fd, &msg, flags);

	cmsg = CMSG_FIRSTHDR(&msg);
	EXPECT_NE(cmsg, NULL);
	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
	ctype = *((unsigned char *)CMSG_DATA(cmsg));
	EXPECT_EQ(ctype, record_type);

	return n;
}

FIXTURE(tls_basic)
{
	int fd, cfd;
	bool notls;
};

FIXTURE_SETUP(tls_basic)
{
	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
}

FIXTURE_TEARDOWN(tls_basic)
{
	close(self->fd);
@@ -199,62 +264,23 @@ FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
FIXTURE_SETUP(tls)
{
	struct tls_crypto_info_keys tls12;
	struct sockaddr_in addr;
	socklen_t len;
	int sfd, ret;

	self->notls = false;
	len = sizeof(addr);
	int ret;

	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
			     &tls12);

	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = 0;

	self->fd = socket(AF_INET, SOCK_STREAM, 0);
	sfd = socket(AF_INET, SOCK_STREAM, 0);

	ret = bind(sfd, &addr, sizeof(addr));
	ASSERT_EQ(ret, 0);
	ret = listen(sfd, 10);
	ASSERT_EQ(ret, 0);

	ret = getsockname(sfd, &addr, &len);
	ASSERT_EQ(ret, 0);

	ret = connect(self->fd, &addr, sizeof(addr));
	ASSERT_EQ(ret, 0);
	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);

	ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	if (ret != 0) {
		self->notls = true;
		printf("Failure setting TCP_ULP, testing without tls\n");
	}

	if (!self->notls) {
		ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12,
				 tls12.len);
		ASSERT_EQ(ret, 0);
	}

	self->cfd = accept(sfd, &addr, &len);
	ASSERT_GE(self->cfd, 0);
	if (self->notls)
		return;

	if (!self->notls) {
		ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls",
				 sizeof("tls"));
	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);

		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12,
				 tls12.len);
	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);
}

	close(sfd);
}

FIXTURE_TEARDOWN(tls)
{
	close(self->fd);
@@ -613,6 +639,95 @@ TEST_F(tls, splice_to_pipe)
	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
}

TEST_F(tls, splice_cmsg_to_pipe)
{
	char *test_str = "test_read";
	char record_type = 100;
	int send_len = 10;
	char buf[10];
	int p[2];

	ASSERT_GE(pipe(p), 0);
	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
	EXPECT_EQ(errno, EINVAL);
	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
	EXPECT_EQ(errno, EIO);
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL),
		  send_len);
	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
}

TEST_F(tls, splice_dec_cmsg_to_pipe)
{
	char *test_str = "test_read";
	char record_type = 100;
	int send_len = 10;
	char buf[10];
	int p[2];

	ASSERT_GE(pipe(p), 0);
	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
	EXPECT_EQ(errno, EIO);
	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
	EXPECT_EQ(errno, EINVAL);
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL),
		  send_len);
	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
}

TEST_F(tls, recv_and_splice)
{
	int send_len = TLS_PAYLOAD_MAX_LEN;
	char mem_send[TLS_PAYLOAD_MAX_LEN];
	char mem_recv[TLS_PAYLOAD_MAX_LEN];
	int half = send_len / 2;
	int p[2];

	ASSERT_GE(pipe(p), 0);
	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
	/* Recv hald of the record, splice the other half */
	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
		  half);
	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
}

TEST_F(tls, peek_and_splice)
{
	int send_len = TLS_PAYLOAD_MAX_LEN;
	char mem_send[TLS_PAYLOAD_MAX_LEN];
	char mem_recv[TLS_PAYLOAD_MAX_LEN];
	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
	int n, i, p[2];

	memrnd(mem_send, sizeof(mem_send));

	ASSERT_GE(pipe(p), 0);
	for (i = 0; i < 4; i++)
		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
			  chunk);

	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
		       MSG_WAITALL | MSG_PEEK),
		  chunk * 5 / 2);
	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);

	n = 0;
	while (n < send_len) {
		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
		EXPECT_GT(i, 0);
		n += i;
	}
	EXPECT_EQ(n, send_len);
	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
}

TEST_F(tls, recvmsg_single)
{
	char const *test_str = "test_recvmsg_single";
@@ -1193,60 +1308,30 @@ TEST_F(tls, mutliproc_sendpage_writers)

TEST_F(tls, control_msg)
{
	if (self->notls)
		return;

	char cbuf[CMSG_SPACE(sizeof(char))];
	char const *test_str = "test_read";
	int cmsg_len = sizeof(char);
	char *test_str = "test_read";
	char record_type = 100;
	struct cmsghdr *cmsg;
	struct msghdr msg;
	int send_len = 10;
	struct iovec vec;
	char buf[10];

	vec.iov_base = (char *)test_str;
	vec.iov_len = 10;
	memset(&msg, 0, sizeof(struct msghdr));
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_TLS;
	/* test sending non-record types. */
	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
	*CMSG_DATA(cmsg) = record_type;
	msg.msg_controllen = cmsg->cmsg_len;
	if (self->notls)
		SKIP(return, "no TLS support");

	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
		  send_len);
	/* Should fail because we didn't provide a control message */
	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);

	vec.iov_base = buf;
	EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL | MSG_PEEK), send_len);

	cmsg = CMSG_FIRSTHDR(&msg);
	EXPECT_NE(cmsg, NULL);
	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
	record_type = *((unsigned char *)CMSG_DATA(cmsg));
	EXPECT_EQ(record_type, 100);
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
		  send_len);
	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);

	/* Recv the message again without MSG_PEEK */
	record_type = 0;
	memset(buf, 0, sizeof(buf));

	EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL), send_len);
	cmsg = CMSG_FIRSTHDR(&msg);
	EXPECT_NE(cmsg, NULL);
	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
	record_type = *((unsigned char *)CMSG_DATA(cmsg));
	EXPECT_EQ(record_type, 100);
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL),
		  send_len);
	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
}

@@ -1301,6 +1386,160 @@ TEST_F(tls, shutdown_reuse)
	EXPECT_EQ(errno, EISCONN);
}

FIXTURE(tls_err)
{
	int fd, cfd;
	int fd2, cfd2;
	bool notls;
};

FIXTURE_VARIANT(tls_err)
{
	uint16_t tls_version;
};

FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
{
	.tls_version = TLS_1_2_VERSION,
};

FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
{
	.tls_version = TLS_1_3_VERSION,
};

FIXTURE_SETUP(tls_err)
{
	struct tls_crypto_info_keys tls12;
	int ret;

	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
			     &tls12);

	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
	if (self->notls)
		return;

	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);

	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);
}

FIXTURE_TEARDOWN(tls_err)
{
	close(self->fd);
	close(self->cfd);
	close(self->fd2);
	close(self->cfd2);
}

TEST_F(tls_err, bad_rec)
{
	char buf[64];

	if (self->notls)
		SKIP(return, "no TLS support");

	memset(buf, 0x55, sizeof(buf));
	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EMSGSIZE);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
	EXPECT_EQ(errno, EAGAIN);
}

TEST_F(tls_err, bad_auth)
{
	char buf[128];
	int n;

	if (self->notls)
		SKIP(return, "no TLS support");

	memrnd(buf, sizeof(buf) / 2);
	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
	n = recv(self->cfd, buf, sizeof(buf), 0);
	EXPECT_GT(n, sizeof(buf) / 2);

	buf[n - 1]++;

	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EBADMSG);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EBADMSG);
}

TEST_F(tls_err, bad_in_large_read)
{
	char txt[3][64];
	char cip[3][128];
	char buf[3 * 128];
	int i, n;

	if (self->notls)
		SKIP(return, "no TLS support");

	/* Put 3 records in the sockets */
	for (i = 0; i < 3; i++) {
		memrnd(txt[i], sizeof(txt[i]));
		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
			  sizeof(txt[i]));
		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
		EXPECT_GT(n, sizeof(txt[i]));
		/* Break the third message */
		if (i == 2)
			cip[2][n - 1]++;
		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
	}

	/* We should be able to receive the first two messages */
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
	/* Third mesasge is bad */
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EBADMSG);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EBADMSG);
}

TEST_F(tls_err, bad_cmsg)
{
	char *test_str = "test_read";
	int send_len = 10;
	char cip[128];
	char buf[128];
	char txt[64];
	int n;

	if (self->notls)
		SKIP(return, "no TLS support");

	/* Queue up one data record */
	memrnd(txt, sizeof(txt));
	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
	n = recv(self->cfd, cip, sizeof(cip), 0);
	EXPECT_GT(n, sizeof(txt));
	EXPECT_EQ(send(self->fd2, cip, n, 0), n);

	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
	n = recv(self->cfd, cip, sizeof(cip), 0);
	cip[n - 1]++; /* Break it */
	EXPECT_GT(n, send_len);
	EXPECT_EQ(send(self->fd2, cip, n, 0), n);

	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EBADMSG);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EBADMSG);
}

TEST(non_established) {
	struct tls12_crypto_info_aes_gcm_256 tls12;
	struct sockaddr_in addr;
@@ -1355,64 +1594,82 @@ TEST(non_established) {

TEST(keysizes) {
	struct tls12_crypto_info_aes_gcm_256 tls12;
	struct sockaddr_in addr;
	int sfd, ret, fd, cfd;
	socklen_t len;
	int ret, fd, cfd;
	bool notls;

	notls = false;
	len = sizeof(addr);

	memset(&tls12, 0, sizeof(tls12));
	tls12.info.version = TLS_1_2_VERSION;
	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;

	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = 0;
	ulp_sock_pair(_metadata, &fd, &cfd, &notls);

	fd = socket(AF_INET, SOCK_STREAM, 0);
	sfd = socket(AF_INET, SOCK_STREAM, 0);
	if (!notls) {
		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
				 sizeof(tls12));
		EXPECT_EQ(ret, 0);

		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
				 sizeof(tls12));
		EXPECT_EQ(ret, 0);
	}

	close(fd);
	close(cfd);
}

TEST(tls_v6ops) {
	struct tls_crypto_info_keys tls12;
	struct sockaddr_in6 addr, addr2;
	int sfd, ret, fd;
	socklen_t len, len2;

	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);

	addr.sin6_family = AF_INET6;
	addr.sin6_addr = in6addr_any;
	addr.sin6_port = 0;

	fd = socket(AF_INET6, SOCK_STREAM, 0);
	sfd = socket(AF_INET6, SOCK_STREAM, 0);

	ret = bind(sfd, &addr, sizeof(addr));
	ASSERT_EQ(ret, 0);
	ret = listen(sfd, 10);
	ASSERT_EQ(ret, 0);

	len = sizeof(addr);
	ret = getsockname(sfd, &addr, &len);
	ASSERT_EQ(ret, 0);

	ret = connect(fd, &addr, sizeof(addr));
	ASSERT_EQ(ret, 0);

	len = sizeof(addr);
	ret = getsockname(fd, &addr, &len);
	ASSERT_EQ(ret, 0);

	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	if (ret != 0) {
		notls = true;
		printf("Failure setting TCP_ULP, testing without tls\n");
	if (ret) {
		ASSERT_EQ(errno, ENOENT);
		SKIP(return, "no TLS support");
	}
	ASSERT_EQ(ret, 0);

	if (!notls) {
		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
				 sizeof(tls12));
		EXPECT_EQ(ret, 0);
	}
	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);

	cfd = accept(sfd, &addr, &len);
	ASSERT_GE(cfd, 0);
	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);

	if (!notls) {
		ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
				 sizeof("tls"));
		EXPECT_EQ(ret, 0);
	len2 = sizeof(addr2);
	ret = getsockname(fd, &addr2, &len2);
	ASSERT_EQ(ret, 0);

		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
				 sizeof(tls12));
		EXPECT_EQ(ret, 0);
	}
	EXPECT_EQ(len2, len);
	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);

	close(sfd);
	close(fd);
	close(cfd);
	close(sfd);
}

TEST_HARNESS_MAIN