Commit c320e527 authored by Moni Shoua's avatar Moni Shoua Committed by Leon Romanovsky
Browse files

IB: Allow calls to ib_umem_get from kernel ULPs



So far the assumption was that ib_umem_get() and ib_umem_odp_get()
are called from flows that start in UVERBS and therefore have a user
context. This assumption restricts flows that are initiated by ULPs
and need the service that ib_umem_get() provides.

This patch changes ib_umem_get() and ib_umem_odp_get() to take the IB
device directly, relying on the fact that both UVERBS and ULPs set that
field correctly.

Reviewed-by: default avatarGuy Levi <guyle@mellanox.com>
Signed-off-by: default avatarMoni Shoua <monis@mellanox.com>
Signed-off-by: default avatarLeon Romanovsky <leonro@mellanox.com>
parent b3a987b0
Loading
Loading
Loading
Loading
+9 −18
Original line number Diff line number Diff line
@@ -181,15 +181,14 @@ EXPORT_SYMBOL(ib_umem_find_best_pgsz);
/**
 * ib_umem_get - Pin and DMA map userspace memory.
 *
 * @udata: userspace context to pin memory for
 * @device: IB device to connect UMEM
 * @addr: userspace virtual address to start at
 * @size: length of region to pin
 * @access: IB_ACCESS_xxx flags for memory being pinned
 */
struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
struct ib_umem *ib_umem_get(struct ib_device *device, unsigned long addr,
			    size_t size, int access)
{
	struct ib_ucontext *context;
	struct ib_umem *umem;
	struct page **page_list;
	unsigned long lock_limit;
@@ -201,14 +200,6 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
	struct scatterlist *sg;
	unsigned int gup_flags = FOLL_WRITE;

	if (!udata)
		return ERR_PTR(-EIO);

	context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
			  ->context;
	if (!context)
		return ERR_PTR(-EIO);

	/*
	 * If the combination of the addr and size requested for this memory
	 * region causes an integer overflow, return error.
@@ -226,7 +217,7 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
	if (!umem)
		return ERR_PTR(-ENOMEM);
	umem->ibdev = context->device;
	umem->ibdev      = device;
	umem->length     = size;
	umem->address    = addr;
	umem->writable   = ib_access_writable(access);
@@ -281,7 +272,7 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
		npages   -= ret;

		sg = ib_umem_add_sg_table(sg, page_list, ret,
			dma_get_max_seg_size(context->device->dma_device),
			dma_get_max_seg_size(device->dma_device),
			&umem->sg_nents);

		up_read(&mm->mmap_sem);
@@ -289,7 +280,7 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,

	sg_mark_end(sg);

	umem->nmap = ib_dma_map_sg(context->device,
	umem->nmap = ib_dma_map_sg(device,
				   umem->sg_head.sgl,
				   umem->sg_nents,
				   DMA_BIDIRECTIONAL);
@@ -303,7 +294,7 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
	goto out;

umem_release:
	__ib_umem_release(context->device, umem, 0);
	__ib_umem_release(device, umem, 0);
vma:
	atomic64_sub(ib_umem_num_pages(umem), &mm->pinned_vm);
out:
+7 −22
Original line number Diff line number Diff line
@@ -110,15 +110,12 @@ static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
 * They exist only to hold the per_mm reference to help the driver create
 * children umems.
 *
 * @udata: udata from the syscall being used to create the umem
 * @device: IB device to create UMEM
 * @access: ib_reg_mr access flags
 */
struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_device *device,
					       int access)
{
	struct ib_ucontext *context =
		container_of(udata, struct uverbs_attr_bundle, driver_udata)
			->context;
	struct ib_umem *umem;
	struct ib_umem_odp *umem_odp;
	int ret;
@@ -126,14 +123,11 @@ struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
	if (access & IB_ACCESS_HUGETLB)
		return ERR_PTR(-EINVAL);

	if (!context)
		return ERR_PTR(-EIO);

	umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
	if (!umem_odp)
		return ERR_PTR(-ENOMEM);
	umem = &umem_odp->umem;
	umem->ibdev = context->device;
	umem->ibdev = device;
	umem->writable = ib_access_writable(access);
	umem->owning_mm = current->mm;
	umem_odp->is_implicit_odp = 1;
@@ -201,7 +195,7 @@ EXPORT_SYMBOL(ib_umem_odp_alloc_child);
/**
 * ib_umem_odp_get - Create a umem_odp for a userspace va
 *
 * @udata: userspace context to pin memory for
 * @device: IB device struct to get UMEM
 * @addr: userspace virtual address to start at
 * @size: length of region to pin
 * @access: IB_ACCESS_xxx flags for memory being pinned
@@ -210,23 +204,14 @@ EXPORT_SYMBOL(ib_umem_odp_alloc_child);
 * pinning, instead, stores the mm for future page fault handling in
 * conjunction with MMU notifiers.
 */
struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
				    size_t size, int access,
struct ib_umem_odp *ib_umem_odp_get(struct ib_device *device,
				    unsigned long addr, size_t size, int access,
				    const struct mmu_interval_notifier_ops *ops)
{
	struct ib_umem_odp *umem_odp;
	struct ib_ucontext *context;
	struct mm_struct *mm;
	int ret;

	if (!udata)
		return ERR_PTR(-EIO);

	context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
			  ->context;
	if (!context)
		return ERR_PTR(-EIO);

	if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)))
		return ERR_PTR(-EINVAL);

@@ -234,7 +219,7 @@ struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
	if (!umem_odp)
		return ERR_PTR(-ENOMEM);

	umem_odp->umem.ibdev = context->device;
	umem_odp->umem.ibdev = device;
	umem_odp->umem.length = size;
	umem_odp->umem.address = addr;
	umem_odp->umem.writable = ib_access_writable(access);
+7 −5
Original line number Diff line number Diff line
@@ -837,7 +837,8 @@ static int bnxt_re_init_user_qp(struct bnxt_re_dev *rdev, struct bnxt_re_pd *pd,
		bytes += (qplib_qp->sq.max_wqe * psn_sz);
	}
	bytes = PAGE_ALIGN(bytes);
	umem = ib_umem_get(udata, ureq.qpsva, bytes, IB_ACCESS_LOCAL_WRITE);
	umem = ib_umem_get(&rdev->ibdev, ureq.qpsva, bytes,
			   IB_ACCESS_LOCAL_WRITE);
	if (IS_ERR(umem))
		return PTR_ERR(umem);

@@ -850,7 +851,7 @@ static int bnxt_re_init_user_qp(struct bnxt_re_dev *rdev, struct bnxt_re_pd *pd,
	if (!qp->qplib_qp.srq) {
		bytes = (qplib_qp->rq.max_wqe * BNXT_QPLIB_MAX_RQE_ENTRY_SIZE);
		bytes = PAGE_ALIGN(bytes);
		umem = ib_umem_get(udata, ureq.qprva, bytes,
		umem = ib_umem_get(&rdev->ibdev, ureq.qprva, bytes,
				   IB_ACCESS_LOCAL_WRITE);
		if (IS_ERR(umem))
			goto rqfail;
@@ -1304,7 +1305,8 @@ static int bnxt_re_init_user_srq(struct bnxt_re_dev *rdev,

	bytes = (qplib_srq->max_wqe * BNXT_QPLIB_MAX_RQE_ENTRY_SIZE);
	bytes = PAGE_ALIGN(bytes);
	umem = ib_umem_get(udata, ureq.srqva, bytes, IB_ACCESS_LOCAL_WRITE);
	umem = ib_umem_get(&rdev->ibdev, ureq.srqva, bytes,
			   IB_ACCESS_LOCAL_WRITE);
	if (IS_ERR(umem))
		return PTR_ERR(umem);

@@ -2545,7 +2547,7 @@ int bnxt_re_create_cq(struct ib_cq *ibcq, const struct ib_cq_init_attr *attr,
			goto fail;
		}

		cq->umem = ib_umem_get(udata, req.cq_va,
		cq->umem = ib_umem_get(&rdev->ibdev, req.cq_va,
				       entries * sizeof(struct cq_base),
				       IB_ACCESS_LOCAL_WRITE);
		if (IS_ERR(cq->umem)) {
@@ -3514,7 +3516,7 @@ struct ib_mr *bnxt_re_reg_user_mr(struct ib_pd *ib_pd, u64 start, u64 length,
	/* The fixed portion of the rkey is the same as the lkey */
	mr->ib_mr.rkey = mr->qplib_mr.rkey;

	umem = ib_umem_get(udata, start, length, mr_access_flags);
	umem = ib_umem_get(&rdev->ibdev, start, length, mr_access_flags);
	if (IS_ERR(umem)) {
		dev_err(rdev_to_dev(rdev), "Failed to get umem");
		rc = -EFAULT;
+1 −1
Original line number Diff line number Diff line
@@ -543,7 +543,7 @@ struct ib_mr *c4iw_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,

	mhp->rhp = rhp;

	mhp->umem = ib_umem_get(udata, start, length, acc);
	mhp->umem = ib_umem_get(pd->device, start, length, acc);
	if (IS_ERR(mhp->umem))
		goto err_free_skb;

+1 −1
Original line number Diff line number Diff line
@@ -1384,7 +1384,7 @@ struct ib_mr *efa_reg_mr(struct ib_pd *ibpd, u64 start, u64 length,
		goto err_out;
	}

	mr->umem = ib_umem_get(udata, start, length, access_flags);
	mr->umem = ib_umem_get(ibpd->device, start, length, access_flags);
	if (IS_ERR(mr->umem)) {
		err = PTR_ERR(mr->umem);
		ibdev_dbg(&dev->ibdev,
Loading