Commit 49ea02d3 authored by Yi Liu's avatar Yi Liu Committed by Jason Gunthorpe
Browse files

vfio: Set device->group in helper function

This avoids referencing device->group in __vfio_register_dev().

Link: https://lore.kernel.org/r/20221201145535.589687-5-yi.l.liu@intel.com


Reviewed-by: default avatarJason Gunthorpe <jgg@nvidia.com>
Reviewed-by: default avatarKevin Tian <kevin.tian@intel.com>
Reviewed-by: default avatarAlex Williamson <alex.williamson@redhat.com>
Tested-by: default avatarLixiao Yang <lixiao.yang@intel.com>
Tested-by: default avatarYu He <yu.he@intel.com>
Signed-off-by: default avatarYi Liu <yi.l.liu@intel.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@nvidia.com>
parent 32e09228
Loading
Loading
Loading
Loading
+26 −15
Original line number Diff line number Diff line
@@ -528,18 +528,29 @@ static void vfio_device_group_unregister(struct vfio_device *device)
	mutex_unlock(&device->group->device_lock);
}

static int __vfio_register_dev(struct vfio_device *device,
		struct vfio_group *group)
static int vfio_device_set_group(struct vfio_device *device,
				 enum vfio_group_type type)
{
	int ret;
	struct vfio_group *group;

	if (type == VFIO_IOMMU)
		group = vfio_group_find_or_alloc(device->dev);
	else
		group = vfio_noiommu_group_alloc(device->dev, type);

	/*
	 * In all cases group is the output of one of the group allocation
	 * functions and we have group->drivers incremented for us.
	 */
	if (IS_ERR(group))
		return PTR_ERR(group);

	/* Our reference on group is moved to the device */
	device->group = group;
	return 0;
}

static int __vfio_register_dev(struct vfio_device *device,
			       enum vfio_group_type type)
{
	int ret;

	if (WARN_ON(device->ops->bind_iommufd &&
		    (!device->ops->unbind_iommufd ||
		     !device->ops->attach_ioas)))
@@ -552,12 +563,13 @@ static int __vfio_register_dev(struct vfio_device *device,
	if (!device->dev_set)
		vfio_assign_device_set(device, device);

	/* Our reference on group is moved to the device */
	device->group = group;

	ret = dev_set_name(&device->device, "vfio%d", device->index);
	if (ret)
		goto err_out;
		return ret;

	ret = vfio_device_set_group(device, type);
	if (ret)
		return ret;

	ret = device_add(&device->device);
	if (ret)
@@ -576,8 +588,7 @@ static int __vfio_register_dev(struct vfio_device *device,

int vfio_register_group_dev(struct vfio_device *device)
{
	return __vfio_register_dev(device,
		vfio_group_find_or_alloc(device->dev));
	return __vfio_register_dev(device, VFIO_IOMMU);
}
EXPORT_SYMBOL_GPL(vfio_register_group_dev);

@@ -587,8 +598,7 @@ EXPORT_SYMBOL_GPL(vfio_register_group_dev);
 */
int vfio_register_emulated_iommu_dev(struct vfio_device *device)
{
	return __vfio_register_dev(device,
		vfio_noiommu_group_alloc(device->dev, VFIO_EMULATED_IOMMU));
	return __vfio_register_dev(device, VFIO_EMULATED_IOMMU);
}
EXPORT_SYMBOL_GPL(vfio_register_emulated_iommu_dev);

@@ -658,6 +668,7 @@ void vfio_unregister_group_dev(struct vfio_device *device)
	/* Balances device_add in register path */
	device_del(&device->device);

	/* Balances vfio_device_set_group in register path */
	vfio_device_remove_group(device);
}
EXPORT_SYMBOL_GPL(vfio_unregister_group_dev);