Commit 52de1fe1 authored by Jens Axboe's avatar Jens Axboe
Browse files

io_uring: add IOSQE_BUFFER_SELECT support for IORING_OP_RECVMSG



Like IORING_OP_READV, this is limited to supporting just a single
segment in the iovec passed in.

Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 0a384abf
Loading
Loading
Loading
Loading
+106 −12
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@
#include <linux/errno.h>
#include <linux/syscalls.h>
#include <linux/compat.h>
#include <net/compat.h>
#include <linux/refcount.h>
#include <linux/uio.h>
#include <linux/bits.h>
@@ -729,6 +730,7 @@ static const struct io_op_def io_op_defs[] = {
		.unbound_nonreg_file	= 1,
		.needs_fs		= 1,
		.pollin			= 1,
		.buffer_select		= 1,
	},
	[IORING_OP_TIMEOUT] = {
		.async_ctx		= 1,
@@ -3569,6 +3571,92 @@ static int io_send(struct io_kiocb *req, bool force_nonblock)
#endif
}

static int __io_recvmsg_copy_hdr(struct io_kiocb *req, struct io_async_ctx *io)
{
	struct io_sr_msg *sr = &req->sr_msg;
	struct iovec __user *uiov;
	size_t iov_len;
	int ret;

	ret = __copy_msghdr_from_user(&io->msg.msg, sr->msg, &io->msg.uaddr,
					&uiov, &iov_len);
	if (ret)
		return ret;

	if (req->flags & REQ_F_BUFFER_SELECT) {
		if (iov_len > 1)
			return -EINVAL;
		if (copy_from_user(io->msg.iov, uiov, sizeof(*uiov)))
			return -EFAULT;
		sr->len = io->msg.iov[0].iov_len;
		iov_iter_init(&io->msg.msg.msg_iter, READ, io->msg.iov, 1,
				sr->len);
		io->msg.iov = NULL;
	} else {
		ret = import_iovec(READ, uiov, iov_len, UIO_FASTIOV,
					&io->msg.iov, &io->msg.msg.msg_iter);
		if (ret > 0)
			ret = 0;
	}

	return ret;
}

#ifdef CONFIG_COMPAT
static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
					struct io_async_ctx *io)
{
	struct compat_msghdr __user *msg_compat;
	struct io_sr_msg *sr = &req->sr_msg;
	struct compat_iovec __user *uiov;
	compat_uptr_t ptr;
	compat_size_t len;
	int ret;

	msg_compat = (struct compat_msghdr __user *) sr->msg;
	ret = __get_compat_msghdr(&io->msg.msg, msg_compat, &io->msg.uaddr,
					&ptr, &len);
	if (ret)
		return ret;

	uiov = compat_ptr(ptr);
	if (req->flags & REQ_F_BUFFER_SELECT) {
		compat_ssize_t clen;

		if (len > 1)
			return -EINVAL;
		if (!access_ok(uiov, sizeof(*uiov)))
			return -EFAULT;
		if (__get_user(clen, &uiov->iov_len))
			return -EFAULT;
		if (clen < 0)
			return -EINVAL;
		sr->len = io->msg.iov[0].iov_len;
		io->msg.iov = NULL;
	} else {
		ret = compat_import_iovec(READ, uiov, len, UIO_FASTIOV,
						&io->msg.iov,
						&io->msg.msg.msg_iter);
		if (ret < 0)
			return ret;
	}

	return 0;
}
#endif

static int io_recvmsg_copy_hdr(struct io_kiocb *req, struct io_async_ctx *io)
{
	io->msg.iov = io->msg.fast_iov;

#ifdef CONFIG_COMPAT
	if (req->ctx->compat)
		return __io_compat_recvmsg_copy_hdr(req, io);
#endif

	return __io_recvmsg_copy_hdr(req, io);
}

static struct io_buffer *io_recv_buffer_select(struct io_kiocb *req,
					       int *cflags, bool needs_lock)
{
@@ -3614,9 +3702,7 @@ static int io_recvmsg_prep(struct io_kiocb *req,
	if (req->flags & REQ_F_NEED_CLEANUP)
		return 0;

	io->msg.iov = io->msg.fast_iov;
	ret = recvmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
					&io->msg.uaddr, &io->msg.iov);
	ret = io_recvmsg_copy_hdr(req, io);
	if (!ret)
		req->flags |= REQ_F_NEED_CLEANUP;
	return ret;
@@ -3630,13 +3716,14 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock)
#if defined(CONFIG_NET)
	struct io_async_msghdr *kmsg = NULL;
	struct socket *sock;
	int ret;
	int ret, cflags = 0;

	if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
		return -EINVAL;

	sock = sock_from_file(req->file, &ret);
	if (sock) {
		struct io_buffer *kbuf;
		struct io_async_ctx io;
		unsigned flags;

@@ -3648,19 +3735,23 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock)
				kmsg->iov = kmsg->fast_iov;
			kmsg->msg.msg_iter.iov = kmsg->iov;
		} else {
			struct io_sr_msg *sr = &req->sr_msg;

			kmsg = &io.msg;
			kmsg->msg.msg_name = &io.msg.addr;

			io.msg.iov = io.msg.fast_iov;
			ret = recvmsg_copy_msghdr(&io.msg.msg, sr->msg,
					sr->msg_flags, &io.msg.uaddr,
					&io.msg.iov);
			ret = io_recvmsg_copy_hdr(req, &io);
			if (ret)
				return ret;
		}

		kbuf = io_recv_buffer_select(req, &cflags, !force_nonblock);
		if (IS_ERR(kbuf)) {
			return PTR_ERR(kbuf);
		} else if (kbuf) {
			kmsg->fast_iov[0].iov_base = u64_to_user_ptr(kbuf->addr);
			iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->iov,
					1, req->sr_msg.len);
		}

		flags = req->sr_msg.msg_flags;
		if (flags & MSG_DONTWAIT)
			req->flags |= REQ_F_NOWAIT;
@@ -3678,7 +3769,7 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock)
	if (kmsg && kmsg->iov != kmsg->fast_iov)
		kfree(kmsg->iov);
	req->flags &= ~REQ_F_NEED_CLEANUP;
	io_cqring_add_event(req, ret);
	__io_cqring_add_event(req, ret, cflags);
	if (ret < 0)
		req_set_fail_links(req);
	io_put_req(req);
@@ -4789,8 +4880,11 @@ static void io_cleanup_req(struct io_kiocb *req)
		if (io->rw.iov != io->rw.fast_iov)
			kfree(io->rw.iov);
		break;
	case IORING_OP_SENDMSG:
	case IORING_OP_RECVMSG:
		if (req->flags & REQ_F_BUFFER_SELECTED)
			kfree(req->sr_msg.kbuf);
		/* fallthrough */
	case IORING_OP_SENDMSG:
		if (io->msg.iov != io->msg.fast_iov)
			kfree(io->msg.iov);
		break;