Commit 31180adb authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

selftests: tls: factor out cmsg send/receive



Add helpers for sending and receiving special record types.

Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent a125f91f
Loading
Loading
Loading
Loading
+70 −41
Original line number Diff line number Diff line
@@ -123,6 +123,65 @@ static void ulp_sock_pair(struct __test_metadata *_metadata,
	ASSERT_EQ(ret, 0);
}

/* Produce a basic cmsg */
static int tls_send_cmsg(int fd, unsigned char record_type,
			 void *data, size_t len, int flags)
{
	char cbuf[CMSG_SPACE(sizeof(char))];
	int cmsg_len = sizeof(char);
	struct cmsghdr *cmsg;
	struct msghdr msg;
	struct iovec vec;

	vec.iov_base = data;
	vec.iov_len = len;
	memset(&msg, 0, sizeof(struct msghdr));
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_TLS;
	/* test sending non-record types. */
	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
	*CMSG_DATA(cmsg) = record_type;
	msg.msg_controllen = cmsg->cmsg_len;

	return sendmsg(fd, &msg, flags);
}

static int tls_recv_cmsg(struct __test_metadata *_metadata,
			 int fd, unsigned char record_type,
			 void *data, size_t len, int flags)
{
	char cbuf[CMSG_SPACE(sizeof(char))];
	struct cmsghdr *cmsg;
	unsigned char ctype;
	struct msghdr msg;
	struct iovec vec;
	int n;

	vec.iov_base = data;
	vec.iov_len = len;
	memset(&msg, 0, sizeof(struct msghdr));
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);

	n = recvmsg(fd, &msg, flags);

	cmsg = CMSG_FIRSTHDR(&msg);
	EXPECT_NE(cmsg, NULL);
	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
	ctype = *((unsigned char *)CMSG_DATA(cmsg));
	EXPECT_EQ(ctype, record_type);

	return n;
}

FIXTURE(tls_basic)
{
	int fd, cfd;
@@ -1160,60 +1219,30 @@ TEST_F(tls, mutliproc_sendpage_writers)

TEST_F(tls, control_msg)
{
	if (self->notls)
		return;

	char cbuf[CMSG_SPACE(sizeof(char))];
	char const *test_str = "test_read";
	int cmsg_len = sizeof(char);
	char *test_str = "test_read";
	char record_type = 100;
	struct cmsghdr *cmsg;
	struct msghdr msg;
	int send_len = 10;
	struct iovec vec;
	char buf[10];

	vec.iov_base = (char *)test_str;
	vec.iov_len = 10;
	memset(&msg, 0, sizeof(struct msghdr));
	msg.msg_iov = &vec;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_TLS;
	/* test sending non-record types. */
	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
	*CMSG_DATA(cmsg) = record_type;
	msg.msg_controllen = cmsg->cmsg_len;
	if (self->notls)
		SKIP(return, "no TLS support");

	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
		  send_len);
	/* Should fail because we didn't provide a control message */
	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);

	vec.iov_base = buf;
	EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL | MSG_PEEK), send_len);

	cmsg = CMSG_FIRSTHDR(&msg);
	EXPECT_NE(cmsg, NULL);
	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
	record_type = *((unsigned char *)CMSG_DATA(cmsg));
	EXPECT_EQ(record_type, 100);
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
		  send_len);
	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);

	/* Recv the message again without MSG_PEEK */
	record_type = 0;
	memset(buf, 0, sizeof(buf));

	EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL), send_len);
	cmsg = CMSG_FIRSTHDR(&msg);
	EXPECT_NE(cmsg, NULL);
	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
	record_type = *((unsigned char *)CMSG_DATA(cmsg));
	EXPECT_EQ(record_type, 100);
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL),
		  send_len);
	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
}