Commit 7d723065 authored by Jens Axboe's avatar Jens Axboe
Browse files

io_wq: add get/put_work handlers to io_wq_create()



For cancellation, we need to ensure that the work item stays valid for
as long as ->cur_work is valid. Right now we can't safely dereference
the work item even under the wqe->lock, because while the ->cur_work
pointer will remain valid, the work could be completing and be freed
in parallel.

Only invoke ->get/put_work() on items we know that the caller queued
themselves. Add IO_WQ_WORK_INTERNAL for io-wq to use, which is needed
when we're queueing a flush item, for instance.

Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 15dff286
Loading
Loading
Loading
Loading
+23 −2
Original line number Diff line number Diff line
@@ -106,6 +106,9 @@ struct io_wq {
	unsigned long state;
	unsigned nr_wqes;

	get_work_fn *get_work;
	put_work_fn *put_work;

	struct task_struct *manager;
	struct user_struct *user;
	struct mm_struct *mm;
@@ -392,7 +395,7 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
static void io_worker_handle_work(struct io_worker *worker)
	__releases(wqe->lock)
{
	struct io_wq_work *work, *old_work;
	struct io_wq_work *work, *old_work = NULL, *put_work = NULL;
	struct io_wqe *wqe = worker->wqe;
	struct io_wq *wq = wqe->wq;

@@ -424,6 +427,8 @@ static void io_worker_handle_work(struct io_worker *worker)
			wqe->flags |= IO_WQE_FLAG_STALLED;

		spin_unlock_irq(&wqe->lock);
		if (put_work && wq->put_work)
			wq->put_work(old_work);
		if (!work)
			break;
next:
@@ -444,6 +449,11 @@ static void io_worker_handle_work(struct io_worker *worker)
		if (worker->mm)
			work->flags |= IO_WQ_WORK_HAS_MM;

		if (wq->get_work && !(work->flags & IO_WQ_WORK_INTERNAL)) {
			put_work = work;
			wq->get_work(work);
		}

		old_work = work;
		work->func(&work);

@@ -455,6 +465,12 @@ static void io_worker_handle_work(struct io_worker *worker)
		}
		if (work && work != old_work) {
			spin_unlock_irq(&wqe->lock);

			if (put_work && wq->put_work) {
				wq->put_work(put_work);
				put_work = NULL;
			}

			/* dependent work not hashed */
			hash = -1U;
			goto next;
@@ -950,13 +966,15 @@ void io_wq_flush(struct io_wq *wq)

		init_completion(&data.done);
		INIT_IO_WORK(&data.work, io_wq_flush_func);
		data.work.flags |= IO_WQ_WORK_INTERNAL;
		io_wqe_enqueue(wqe, &data.work);
		wait_for_completion(&data.done);
	}
}

struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
			   struct user_struct *user)
			   struct user_struct *user, get_work_fn *get_work,
			   put_work_fn *put_work)
{
	int ret = -ENOMEM, i, node;
	struct io_wq *wq;
@@ -972,6 +990,9 @@ struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
		return ERR_PTR(-ENOMEM);
	}

	wq->get_work = get_work;
	wq->put_work = put_work;

	/* caller must already hold a reference to this */
	wq->user = user;

+6 −1
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ enum {
	IO_WQ_WORK_NEEDS_USER	= 8,
	IO_WQ_WORK_NEEDS_FILES	= 16,
	IO_WQ_WORK_UNBOUND	= 32,
	IO_WQ_WORK_INTERNAL	= 64,

	IO_WQ_HASH_SHIFT	= 24,	/* upper 8 bits are used for hash key */
};
@@ -34,8 +35,12 @@ struct io_wq_work {
		(work)->files = NULL;			\
	} while (0)					\

typedef void (get_work_fn)(struct io_wq_work *);
typedef void (put_work_fn)(struct io_wq_work *);

struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
				struct user_struct *user);
				struct user_struct *user,
				get_work_fn *get_work, put_work_fn *put_work);
void io_wq_destroy(struct io_wq *wq);

void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work);
+16 −1
Original line number Diff line number Diff line
@@ -3822,6 +3822,20 @@ static int io_sqe_files_update(struct io_ring_ctx *ctx, void __user *arg,
	return done ? done : err;
}

static void io_put_work(struct io_wq_work *work)
{
	struct io_kiocb *req = container_of(work, struct io_kiocb, work);

	io_put_req(req);
}

static void io_get_work(struct io_wq_work *work)
{
	struct io_kiocb *req = container_of(work, struct io_kiocb, work);

	refcount_inc(&req->refs);
}

static int io_sq_offload_start(struct io_ring_ctx *ctx,
			       struct io_uring_params *p)
{
@@ -3871,7 +3885,8 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,

	/* Do QD, or 4 * CPUS, whatever is smallest */
	concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
	ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user);
	ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user,
					io_get_work, io_put_work);
	if (IS_ERR(ctx->io_wq)) {
		ret = PTR_ERR(ctx->io_wq);
		ctx->io_wq = NULL;