Commit 3ce4cb81 authored by xiongmengbiao's avatar xiongmengbiao
Browse files

drivers/crypto/ccp: support TKM run on CSV

hygon inclusion
category: feature
bugzilla: https://gitee.com/openeuler/kernel/issues/IAMHXY


CVE: NA

---------------------------

The CSV virtual machine puts the tkm command data
into an encrypted 2MB hugepage space,
then submits the command address GPA to the host kernel,
then the host kernel converts the GPA to HPA,
and finally submits the HPA to the PSP hardware.

During the entire command forwarding process,
since the data of the CSV virtual machine is encrypted,
the host kernel cannot access the real TKM command data submitted by CSV.

Signed-off-by: default avatarxiongmengbiao <xiongmengbiao@hygon.cn>
parent 0195654d
Loading
Loading
Loading
Loading
+11 −2
Original line number Diff line number Diff line
@@ -5215,13 +5215,22 @@ static int kvm_hygon_arch_hypercall(struct kvm *kvm, u64 nr, u64 a0, u64 a1, u64
	struct kvm_vpsp vpsp = {
		.kvm = kvm,
		.write_guest = kvm_write_guest,
		.read_guest = kvm_read_guest
		.read_guest = kvm_read_guest,
		.gfn_to_pfn = gfn_to_pfn,
	};

	if (sev_guest(kvm)) {
		vpsp.vm_handle = to_kvm_svm(kvm)->sev_info.handle;
		vpsp.is_csv_guest = 1;
	}

	switch (nr) {
	case KVM_HC_PSP_COPY_FORWARD_OP:
		ret = kvm_pv_psp_copy_forward_op(&vpsp, a0, a1, a2);
		break;

	case KVM_HC_PSP_FORWARD_OP:
		ret = kvm_pv_psp_forward_op(&vpsp, a0, a1, a2);
		break;
	default:
		ret = -KVM_ENOSYS;
		break;
+3 −1
Original line number Diff line number Diff line
@@ -9974,7 +9974,8 @@ int kvm_emulate_hypercall(struct kvm_vcpu *vcpu)
	if (static_call(kvm_x86_get_cpl)(vcpu) != 0 &&
	    !(is_x86_vendor_hygon() && (nr == KVM_HC_VM_ATTESTATION
					|| nr == KVM_HC_PSP_OP_OBSOLETE
					|| nr == KVM_HC_PSP_COPY_FORWARD_OP))) {
					|| nr == KVM_HC_PSP_COPY_FORWARD_OP
					|| nr == KVM_HC_PSP_FORWARD_OP))) {
		ret = -KVM_EPERM;
		goto out;
	}
@@ -10013,6 +10014,7 @@ int kvm_emulate_hypercall(struct kvm_vcpu *vcpu)
		break;
	case KVM_HC_PSP_OP_OBSOLETE:
	case KVM_HC_PSP_COPY_FORWARD_OP:
	case KVM_HC_PSP_FORWARD_OP:
		ret = -KVM_ENOSYS;
		if (kvm_arch_hypercall)
			ret = kvm_arch_hypercall(vcpu->kvm, nr, a0, a1, a2, a3);
+14 −26
Original line number Diff line number Diff line
@@ -646,12 +646,12 @@ static int vpsp_dequeue_cmd(int prio, int index,
 * Populate the command from the virtual machine to the queue to
 * support execution in ringbuffer mode
 */
static int vpsp_fill_cmd_queue(uint32_t vid, int prio, int cmd, void *data, uint16_t flags)
static int vpsp_fill_cmd_queue(int prio, int cmd, phys_addr_t phy_addr, uint16_t flags)
{
	struct csv_cmdptr_entry cmdptr = { };
	int index = -1;

	cmdptr.cmd_buf_ptr = PUT_PSP_VID(__psp_pa(data), vid);
	cmdptr.cmd_buf_ptr = phy_addr;
	cmdptr.cmd_id = cmd;
	cmdptr.cmd_flags = flags;

@@ -939,11 +939,10 @@ static int vpsp_rb_check_and_cmd_prio_parse(uint8_t *prio,
	return rb_supported;
}

int __vpsp_do_cmd_locked(uint32_t vid, int cmd, void *data, int *psp_ret)
static int __vpsp_do_cmd_locked(int cmd, phys_addr_t phy_addr, int *psp_ret)
{
	struct psp_device *psp = psp_master;
	struct sev_device *sev;
	phys_addr_t phys_addr;
	unsigned int phys_lsb, phys_msb;
	unsigned int reg, ret = 0;

@@ -955,20 +954,13 @@ int __vpsp_do_cmd_locked(uint32_t vid, int cmd, void *data, int *psp_ret)

	sev = psp->sev_data;

	if (data && WARN_ON_ONCE(!virt_addr_valid(data)))
		return -EINVAL;

	/* Get the physical address of the command buffer */
	phys_addr = PUT_PSP_VID(__psp_pa(data), vid);
	phys_lsb = data ? lower_32_bits(phys_addr) : 0;
	phys_msb = data ? upper_32_bits(phys_addr) : 0;
	phys_lsb = phy_addr ? lower_32_bits(phy_addr) : 0;
	phys_msb = phy_addr ? upper_32_bits(phy_addr) : 0;

	dev_dbg(sev->dev, "sev command id %#x buffer 0x%08x%08x timeout %us\n",
		cmd, phys_msb, phys_lsb, *hygon_psp_hooks.psp_timeout);

	print_hex_dump_debug("(in):  ", DUMP_PREFIX_OFFSET, 16, 2, data,
			     hygon_psp_hooks.sev_cmd_buffer_len(cmd), false);

	iowrite32(phys_lsb, sev->io_regs + sev->vdata->cmdbuff_addr_lo_reg);
	iowrite32(phys_msb, sev->io_regs + sev->vdata->cmdbuff_addr_hi_reg);

@@ -1000,13 +992,10 @@ int __vpsp_do_cmd_locked(uint32_t vid, int cmd, void *data, int *psp_ret)
		ret = -EIO;
	}

	print_hex_dump_debug("(out): ", DUMP_PREFIX_OFFSET, 16, 2, data,
			     hygon_psp_hooks.sev_cmd_buffer_len(cmd), false);

	return ret;
}

int vpsp_do_cmd(uint32_t vid, int cmd, void *data, int *psp_ret)
int vpsp_do_cmd(int cmd, phys_addr_t phy_addr, int *psp_ret)
{
	int rc;
	int mutex_enabled = READ_ONCE(hygon_psp_hooks.psp_mutex_enabled);
@@ -1020,7 +1009,7 @@ int vpsp_do_cmd(uint32_t vid, int cmd, void *data, int *psp_ret)
		mutex_lock(hygon_psp_hooks.sev_cmd_mutex);
	}

	rc = __vpsp_do_cmd_locked(vid, cmd, data, psp_ret);
	rc = __vpsp_do_cmd_locked(cmd, phy_addr, psp_ret);

	if (is_vendor_hygon() && mutex_enabled)
		psp_mutex_unlock(&hygon_psp_hooks.psp_misc->data_pg_aligned->mb_mutex);
@@ -1034,7 +1023,7 @@ int vpsp_do_cmd(uint32_t vid, int cmd, void *data, int *psp_ret)
 * Try to obtain the result again by the command index, this
 * interface is used in ringbuffer mode
 */
int vpsp_try_get_result(uint32_t vid, uint8_t prio, uint32_t index, void *data,
int vpsp_try_get_result(uint8_t prio, uint32_t index, phys_addr_t phy_addr,
		struct vpsp_ret *psp_ret)
{
	int ret = 0;
@@ -1054,8 +1043,7 @@ int vpsp_try_get_result(uint32_t vid, uint8_t prio, uint32_t index, void *data,
			/* dequeue command from queue*/
			vpsp_dequeue_cmd(prio, index, &cmd);

			ret = __vpsp_do_cmd_locked(vid, cmd.cmd_id, data,
					(int *)psp_ret);
			ret = __vpsp_do_cmd_locked(cmd.cmd_id, phy_addr, (int *)psp_ret);
			psp_ret->status = VPSP_FINISH;
			vpsp_psp_mutex_unlock();
			if (unlikely(ret)) {
@@ -1098,7 +1086,7 @@ EXPORT_SYMBOL_GPL(vpsp_try_get_result);
 * vpsp_try_get_result interface will be used to obtain the result
 * later again
 */
int vpsp_try_do_cmd(uint32_t vid, int cmd, void *data, struct vpsp_ret *psp_ret)
int vpsp_try_do_cmd(int cmd, phys_addr_t phy_addr, struct vpsp_ret *psp_ret)
{
	int ret = 0;
	int rb_supported;
@@ -1110,10 +1098,10 @@ int vpsp_try_do_cmd(uint32_t vid, int cmd, void *data, struct vpsp_ret *psp_ret)
			(struct vpsp_cmd *)&cmd);
	if (rb_supported) {
		/* fill command in ringbuffer's queue and get index */
		index = vpsp_fill_cmd_queue(vid, prio, cmd, data, 0);
		index = vpsp_fill_cmd_queue(prio, cmd, phy_addr, 0);
		if (unlikely(index < 0)) {
			/* do mailbox command if queuing failed*/
			ret = vpsp_do_cmd(vid, cmd, data, (int *)psp_ret);
			ret = vpsp_do_cmd(cmd, phy_addr, (int *)psp_ret);
			if (unlikely(ret)) {
				if (ret == -EIO) {
					ret = 0;
@@ -1129,14 +1117,14 @@ int vpsp_try_do_cmd(uint32_t vid, int cmd, void *data, struct vpsp_ret *psp_ret)
		}

		/* try to get result from the ringbuffer command */
		ret = vpsp_try_get_result(vid, prio, index, data, psp_ret);
		ret = vpsp_try_get_result(prio, index, phy_addr, psp_ret);
		if (unlikely(ret)) {
			pr_err("[%s]: vpsp_try_get_result failed %d\n", __func__, ret);
			goto end;
		}
	} else {
		/* mailbox mode */
		ret = vpsp_do_cmd(vid, cmd, data, (int *)psp_ret);
		ret = vpsp_do_cmd(cmd, phy_addr, (int *)psp_ret);
		if (unlikely(ret)) {
			if (ret == -EIO) {
				ret = 0;
+56 −30
Original line number Diff line number Diff line
@@ -38,16 +38,26 @@ enum VPSP_DEV_CTRL_OPCODE {
	VPSP_OP_VID_DEL,
	VPSP_OP_SET_DEFAULT_VID_PERMISSION,
	VPSP_OP_GET_DEFAULT_VID_PERMISSION,
	VPSP_OP_SET_GPA,
};

struct vpsp_dev_ctrl {
	unsigned char op;
	/**
	 * To be compatible with old user mode,
	 * struct vpsp_dev_ctrl must be kept at 132 bytes.
	 */
	unsigned char resv[3];
	union {
		unsigned int vid;
		// Set or check the permissions for the default VID
		unsigned int def_vid_perm;
		struct {
			u64 gpa_start;
			u64 gpa_end;
		} gpa;
		unsigned char reserved[128];
	} data;
	} __packed data;
};

uint64_t atomic64_exchange(uint64_t *dst, uint64_t val)
@@ -160,19 +170,15 @@ DEFINE_RWLOCK(vpsp_rwlock);
#define VPSP_VID_MAX_ENTRIES    2048
#define VPSP_VID_NUM_MAX        64

struct vpsp_vid_entry {
	uint32_t vid;
	pid_t pid;
};
static struct vpsp_vid_entry g_vpsp_vid_array[VPSP_VID_MAX_ENTRIES];
static struct vpsp_context g_vpsp_context_array[VPSP_VID_MAX_ENTRIES];
static uint32_t g_vpsp_vid_num;
static int compare_vid_entries(const void *a, const void *b)
{
	return ((struct vpsp_vid_entry *)a)->pid - ((struct vpsp_vid_entry *)b)->pid;
	return ((struct vpsp_context *)a)->pid - ((struct vpsp_context *)b)->pid;
}
static void swap_vid_entries(void *a, void *b, int size)
{
	struct vpsp_vid_entry entry;
	struct vpsp_context entry;

	memcpy(&entry, a, size);
	memcpy(a, b, size);
@@ -197,43 +203,41 @@ int vpsp_get_default_vid_permission(void)
EXPORT_SYMBOL_GPL(vpsp_get_default_vid_permission);

/**
 * When the virtual machine executes the 'tkm' command,
 * it needs to retrieve the corresponding 'vid'
 * by performing a binary search using 'kvm->userspace_pid'.
 * get a vpsp context from pid
 */
int vpsp_get_vid(uint32_t *vid, pid_t pid)
int vpsp_get_context(struct vpsp_context **ctx, pid_t pid)
{
	struct vpsp_vid_entry new_entry = {.pid = pid};
	struct vpsp_vid_entry *existing_entry = NULL;
	struct vpsp_context new_entry = {.pid = pid};
	struct vpsp_context *existing_entry = NULL;

	read_lock(&vpsp_rwlock);
	existing_entry = bsearch(&new_entry, g_vpsp_vid_array, g_vpsp_vid_num,
				sizeof(struct vpsp_vid_entry), compare_vid_entries);
	existing_entry = bsearch(&new_entry, g_vpsp_context_array, g_vpsp_vid_num,
				sizeof(struct vpsp_context), compare_vid_entries);
	read_unlock(&vpsp_rwlock);

	if (!existing_entry)
		return -ENOENT;
	if (vid) {
		*vid = existing_entry->vid;
		pr_debug("PSP: %s %d, by pid %d\n", __func__, *vid, pid);
	}

	if (ctx)
		*ctx = existing_entry;

	return 0;
}
EXPORT_SYMBOL_GPL(vpsp_get_vid);
EXPORT_SYMBOL_GPL(vpsp_get_context);

/**
 * Upon qemu startup, this section checks whether
 * the '-device psp,vid' parameter is specified.
 * If set, it utilizes the 'vpsp_add_vid' function
 * to insert the 'vid' and 'pid' values into the 'g_vpsp_vid_array'.
 * to insert the 'vid' and 'pid' values into the 'g_vpsp_context_array'.
 * The insertion is done in ascending order of 'pid'.
 */
static int vpsp_add_vid(uint32_t vid)
{
	pid_t cur_pid = task_pid_nr(current);
	struct vpsp_vid_entry new_entry = {.vid = vid, .pid = cur_pid};
	struct vpsp_context new_entry = {.vid = vid, .pid = cur_pid};

	if (vpsp_get_vid(NULL, cur_pid) == 0)
	if (vpsp_get_context(NULL, cur_pid) == 0)
		return -EEXIST;
	if (g_vpsp_vid_num == VPSP_VID_MAX_ENTRIES)
		return -ENOMEM;
@@ -241,8 +245,8 @@ static int vpsp_add_vid(uint32_t vid)
		return -EINVAL;

	write_lock(&vpsp_rwlock);
	memcpy(&g_vpsp_vid_array[g_vpsp_vid_num++], &new_entry, sizeof(struct vpsp_vid_entry));
	sort(g_vpsp_vid_array, g_vpsp_vid_num, sizeof(struct vpsp_vid_entry),
	memcpy(&g_vpsp_context_array[g_vpsp_vid_num++], &new_entry, sizeof(struct vpsp_context));
	sort(g_vpsp_context_array, g_vpsp_vid_num, sizeof(struct vpsp_context),
				compare_vid_entries, swap_vid_entries);
	pr_info("PSP: add vid %d, by pid %d, total vid num is %d\n", vid, cur_pid, g_vpsp_vid_num);
	write_unlock(&vpsp_rwlock);
@@ -261,12 +265,12 @@ static int vpsp_del_vid(void)

	write_lock(&vpsp_rwlock);
	for (i = 0; i < g_vpsp_vid_num; ++i) {
		if (g_vpsp_vid_array[i].pid == cur_pid) {
		if (g_vpsp_context_array[i].pid == cur_pid) {
			--g_vpsp_vid_num;
			pr_info("PSP: delete vid %d, by pid %d, total vid num is %d\n",
				g_vpsp_vid_array[i].vid, cur_pid, g_vpsp_vid_num);
			memmove(&g_vpsp_vid_array[i], &g_vpsp_vid_array[i + 1],
				sizeof(struct vpsp_vid_entry) * (g_vpsp_vid_num - i));
				g_vpsp_context_array[i].vid, cur_pid, g_vpsp_vid_num);
			memmove(&g_vpsp_context_array[i], &g_vpsp_context_array[i + 1],
				sizeof(struct vpsp_context) * (g_vpsp_vid_num - i));
			ret = 0;
			goto end;
		}
@@ -277,6 +281,24 @@ static int vpsp_del_vid(void)
	return ret;
}

static int vpsp_set_gpa_range(u64 gpa_start, u64 gpa_end)
{
	pid_t cur_pid = task_pid_nr(current);
	struct vpsp_context *ctx = NULL;

	vpsp_get_context(&ctx, cur_pid);
	if (!ctx) {
		pr_err("PSP: %s get vpsp_context failed from pid %d\n", __func__, cur_pid);
		return -ENOENT;
	}

	ctx->gpa_start = gpa_start;
	ctx->gpa_end = gpa_end;
	pr_info("PSP: set gpa range (start 0x%llx, end 0x%llx), by pid %d\n",
		gpa_start, gpa_end, cur_pid);
	return 0;
}

static int do_vpsp_op_ioctl(struct vpsp_dev_ctrl *ctrl)
{
	int ret = 0;
@@ -299,6 +321,10 @@ static int do_vpsp_op_ioctl(struct vpsp_dev_ctrl *ctrl)
		ctrl->data.def_vid_perm = vpsp_get_default_vid_permission();
		break;

	case VPSP_OP_SET_GPA:
		ret = vpsp_set_gpa_range(ctrl->data.gpa.gpa_start, ctrl->data.gpa.gpa_end);
		break;

	default:
		ret = -EINVAL;
		break;
+292 −39
Original line number Diff line number Diff line
@@ -18,26 +18,32 @@
#undef pr_fmt
#endif
#define pr_fmt(fmt) "vpsp: " fmt
#define VTKM_VM_BIND	0x904

/*
 * The file mainly implements the base execution
 * logic of virtual PSP in kernel mode, which mainly includes:
 *	(1) Obtain the VM command and preprocess the pointer
 *		mapping table information in the command buffer
 *	(2) The command that has been converted will interact
 *		with the channel of the psp through the driver and
 *		try to obtain the execution result
 *	(3) The executed command data is recovered according to
 *		the multilevel pointer of the mapping table, and then returned to the VM
 * The file mainly implements the base execution logic of virtual PSP in kernel mode,
 *	which mainly includes:
 *	(1) Preprocess the guest data in the host kernel
 *	(2) The command that has been converted will interact with the channel of the
 *		psp through the driver and try to obtain the execution result
 *	(3) The executed command data is recovered, and then returned to the VM
 *
 * The primary implementation logic of virtual PSP in kernel mode
 * call trace:
 * guest command(vmmcall)
 *		   |-> kvm_pv_psp_cmd_pre_op
 * guest command(vmmcall, KVM_HC_PSP_COPY_FORWARD_OP)
 *		   |
 *	kvm_pv_psp_copy_op---->	| -> kvm_pv_psp_cmd_pre_op
 *				|
 *				| -> vpsp_try_do_cmd/vpsp_try_get_result
 *				|	|<=> psp device driver
 *				|
 *  kvm_pv_psp_copy_forward_op->|-> vpsp_try_do_cmd/vpsp_try_get_result <====> psp device driver
 *				|
 *				|-> kvm_pv_psp_cmd_post_op
 *
 * guest command(vmmcall, KVM_HC_PSP_FORWARD_OP)
 *		   |
 *	kvm_pv_psp_forward_op-> |-> vpsp_try_do_cmd/vpsp_try_get_result
 *					|<=> psp device driver
 */

struct psp_cmdresp_head {
@@ -56,10 +62,36 @@ struct vpsp_hbuf_wrapper {
struct vpsp_hbuf_wrapper
g_hbuf_wrap[CSV_COMMAND_PRIORITY_NUM][CSV_RING_BUFFER_SIZE / CSV_RING_BUFFER_ESIZE] = {0};

/*
 * Obtain the VM command and preprocess the pointer mapping table
 * information in the command buffer, the processed data will be
 * used to interact with the psp device
static int check_gpa_range(struct vpsp_context *vpsp_ctx, gpa_t addr, uint32_t size)
{
	if (!vpsp_ctx || !addr)
		return -EFAULT;

	if (addr >= vpsp_ctx->gpa_start && (addr + size) <= vpsp_ctx->gpa_end)
		return 0;
	return -EFAULT;
}

static int check_psp_mem_range(struct vpsp_context *vpsp_ctx,
			void *data, uint32_t size)
{
	if ((((uintptr_t)data + size - 1) & ~PSP_2MB_MASK) !=
			((uintptr_t)data & ~PSP_2MB_MASK)) {
		pr_err("data %llx, size %d crossing 2MB\n", (u64)data, size);
		return -EFAULT;
	}

	if (vpsp_ctx)
		return check_gpa_range(vpsp_ctx, (gpa_t)data, size);

	return 0;
}

/**
 * Copy the guest data to the host kernel buffer
 * and record the host buffer address in 'hbuf'.
 * This 'hbuf' is used to restore context information
 * during asynchronous processing.
 */
static int kvm_pv_psp_cmd_pre_op(struct kvm_vpsp *vpsp, gpa_t data_gpa,
		struct vpsp_hbuf_wrapper *hbuf)
@@ -74,11 +106,8 @@ static int kvm_pv_psp_cmd_pre_op(struct kvm_vpsp *vpsp, gpa_t data_gpa,
		return -EFAULT;

	data_size = psp_head.buf_size;
	if ((((uintptr_t)data_gpa + data_size - 1) & ~PSP_2MB_MASK)
			!= ((uintptr_t)data_gpa & ~PSP_2MB_MASK)) {
		pr_err("data_gpa %llx, data_size %d crossing 2MB\n", (u64)data_gpa, data_size);
	if (check_psp_mem_range(NULL, data_gpa, data_size))
		return -EFAULT;
	}

	data = kzalloc(data_size, GFP_KERNEL);
	if (!data)
@@ -122,9 +151,234 @@ static int cmd_type_is_tkm(int cmd)
	return 0;
}

static int cmd_type_is_allowed(int cmd)
{
	if (cmd >= TKM_PSP_CMDID_OFFSET && cmd <= TKM_CMD_ID_MAX)
		return 1;
	return 0;
}

struct psp_cmdresp_vtkm_vm_bind {
	struct psp_cmdresp_head head;
	uint16_t vid;
	uint32_t vm_handle;
	uint8_t reserved[46];
} __packed;

static int kvm_bind_vtkm(uint32_t vm_handle, uint32_t cmd_id, uint32_t vid, uint32_t *pret)
{
	int ret = 0;
	struct psp_cmdresp_vtkm_vm_bind *data;

	data = kzalloc(sizeof(*data), GFP_KERNEL);
	if (!data)
		return -ENOMEM;

	data->head.buf_size = sizeof(*data);
	data->head.cmdresp_size = sizeof(*data);
	data->head.cmdresp_code = VTKM_VM_BIND;
	data->vid = vid;
	data->vm_handle = vm_handle;

	ret = psp_do_cmd(cmd_id, data, pret);
	if (ret == -EIO)
		ret = 0;

	kfree(data);
	return ret;
}

static phys_addr_t gpa_to_hpa(struct kvm_vpsp *vpsp, unsigned long data_gpa)
{
	phys_addr_t hpa = 0;
	unsigned long pfn = vpsp->gfn_to_pfn(vpsp->kvm, data_gpa >> PAGE_SHIFT);

	if (!is_error_pfn(pfn))
		hpa = ((pfn << PAGE_SHIFT) + offset_in_page(data_gpa)) | sme_get_me_mask();

	pr_debug("gpa %lx, hpa %llx\n", data_gpa, hpa);
	return hpa;

}

static int check_cmd_forward_op_permission(struct kvm_vpsp *vpsp, struct vpsp_context *vpsp_ctx,
				uint64_t data, uint32_t cmd)
{
	int ret;
	struct vpsp_cmd *vcmd = (struct vpsp_cmd *)&cmd;
	struct psp_cmdresp_head psp_head;

	if (!cmd_type_is_allowed(vcmd->cmd_id)) {
		pr_err("[%s]: unsupported cmd id %x\n", __func__, vcmd->cmd_id);
		return -EINVAL;
	}

	if (vpsp->is_csv_guest) {
		/**
 * @brief kvm_pv_psp_copy_forward_op is used for ordinary virtual machines to copy data
 * in gpa to host memory and send it to psp for processing.
		 * If the gpa address range exists,
		 * it means there must be a legal vid
		 */
		if (!vpsp_ctx || !vpsp_ctx->gpa_start || !vpsp_ctx->gpa_end) {
			pr_err("[%s]: No set gpa range or vid in csv guest\n", __func__);
			return -EPERM;
		}

		ret = check_psp_mem_range(vpsp_ctx, (void *)data, 0);
		if (ret)
			return -EFAULT;
	} else {
		if (!vpsp_ctx && cmd_type_is_tkm(vcmd->cmd_id)
				&& !vpsp_get_default_vid_permission()) {
			pr_err("[%s]: not allowed tkm command without vid\n", __func__);
			return -EPERM;
		}

		// the 'data' is gpa address
		if (unlikely(vpsp->read_guest(vpsp->kvm, data, &psp_head,
					sizeof(struct psp_cmdresp_head))))
			return -EFAULT;

		ret = check_psp_mem_range(vpsp_ctx, (void *)data, psp_head.buf_size);
		if (ret)
			return -EFAULT;
	}
	return 0;
}

static int
check_cmd_copy_forward_op_permission(struct kvm_vpsp *vpsp,
				struct vpsp_context *vpsp_ctx,
				uint64_t data, uint32_t cmd)
{
	int ret = 0;
	struct vpsp_cmd *vcmd = (struct vpsp_cmd *)&cmd;

	if (!cmd_type_is_allowed(vcmd->cmd_id)) {
		pr_err("[%s]: unsupported cmd id %x\n", __func__, vcmd->cmd_id);
		return -EINVAL;
	}

	if (vpsp->is_csv_guest) {
		pr_err("[%s]: unsupported run on csv guest\n", __func__);
		ret = -EPERM;
	} else {
		if (!vpsp_ctx && cmd_type_is_tkm(vcmd->cmd_id)
				&& !vpsp_get_default_vid_permission()) {
			pr_err("[%s]: not allowed tkm command without vid\n", __func__);
			ret = -EPERM;
		}
	}
	return ret;
}

static int vpsp_try_bind_vtkm(struct kvm_vpsp *vpsp, struct vpsp_context *vpsp_ctx,
				uint32_t cmd, uint32_t *psp_ret)
{
	int ret;
	struct vpsp_cmd *vcmd = (struct vpsp_cmd *)&cmd;

	if (vpsp_ctx && !vpsp_ctx->vm_is_bound && vpsp->is_csv_guest) {
		ret = kvm_bind_vtkm(vpsp->vm_handle, vcmd->cmd_id,
					vpsp_ctx->vid, psp_ret);
		if (ret || *psp_ret) {
			pr_err("[%s] kvm bind vtkm failed with ret: %d, pspret: %d\n",
				__func__, ret, *psp_ret);
			return ret;
		}
		vpsp_ctx->vm_is_bound = 1;
	}
	return 0;
}

/**
 * @brief Directly convert the gpa address into hpa and forward it to PSP,
 *	  It is another form of kvm_pv_psp_copy_op, mainly used for csv VMs.
 *
 * @param vpsp points to kvm related data
 * @param cmd psp cmd id, bit 31 indicates queue priority
 * @param data_gpa guest physical address of input data
 * @param psp_ret indicates Asynchronous context information
 *
 * Since the csv guest memory cannot be read or written directly,
 * the shared asynchronous context information is shared through psp_ret and return value.
 */
int kvm_pv_psp_forward_op(struct kvm_vpsp *vpsp, uint32_t cmd,
			gpa_t data_gpa, uint32_t psp_ret)
{
	int ret;
	uint64_t data_hpa;
	uint32_t index = 0, vid = 0;
	struct vpsp_ret psp_async = {0};
	struct vpsp_context *vpsp_ctx = NULL;
	struct vpsp_cmd *vcmd = (struct vpsp_cmd *)&cmd;
	uint8_t prio = CSV_COMMAND_PRIORITY_LOW;

	vpsp_get_context(&vpsp_ctx, vpsp->kvm->userspace_pid);

	ret = check_cmd_forward_op_permission(vpsp, vpsp_ctx, data_gpa, cmd);
	if (unlikely(ret)) {
		pr_err("directly operation not allowed\n");
		goto end;
	}

	ret = vpsp_try_bind_vtkm(vpsp, vpsp_ctx, cmd, (uint32_t *)&psp_async);
	if (unlikely(ret || *(uint32_t *)&psp_async)) {
		pr_err("try to bind vtkm failed (ret %x, psp_async %x)\n",
			ret, *(uint32_t *)&psp_async);
		goto end;
	}

	if (vpsp_ctx)
		vid = vpsp_ctx->vid;

	*((uint32_t *)&psp_async) = psp_ret;
	data_hpa = PUT_PSP_VID(gpa_to_hpa(vpsp, data_gpa), vid);

	switch (psp_async.status) {
	case VPSP_INIT:
		/* try to send command to the device for execution*/
		ret = vpsp_try_do_cmd(cmd, data_hpa, &psp_async);
		if (unlikely(ret)) {
			pr_err("[%s]: vpsp_do_cmd failed\n", __func__);
			goto end;
		}
		break;

	case VPSP_RUNNING:
		prio = vcmd->is_high_rb ? CSV_COMMAND_PRIORITY_HIGH :
			CSV_COMMAND_PRIORITY_LOW;
		index = psp_async.index;
		/* try to get the execution result from ringbuffer*/
		ret = vpsp_try_get_result(prio, index, data_hpa, &psp_async);
		if (unlikely(ret)) {
			pr_err("[%s]: vpsp_try_get_result failed\n", __func__);
			goto end;
		}
		break;

	default:
		pr_err("[%s]: invalid command status\n", __func__);
		break;
	}

end:
	/**
	 * In order to indicate both system errors and PSP errors,
	 * the psp_async.pret field needs to be reused.
	 */
	psp_async.format = VPSP_RET_PSP_FORMAT;
	if (ret) {
		psp_async.format = VPSP_RET_SYS_FORMAT;
		if (ret > 0)
			ret = -ret;
		psp_async.pret = (uint16_t)ret;
	}
	return *((int *)&psp_async);
}
EXPORT_SYMBOL_GPL(kvm_pv_psp_forward_op);

/**
 * @brief copy data in gpa to host memory and send it to psp for processing.
 *
 * @param vpsp points to kvm related data
 * @param cmd psp cmd id, bit 31 indicates queue priority
@@ -137,24 +391,22 @@ int kvm_pv_psp_copy_forward_op(struct kvm_vpsp *vpsp, int cmd, gpa_t data_gpa, g
	struct vpsp_ret psp_ret = {0};
	struct vpsp_hbuf_wrapper hbuf = {0};
	struct vpsp_cmd *vcmd = (struct vpsp_cmd *)&cmd;
	struct vpsp_context *vpsp_ctx = NULL;
	phys_addr_t data_paddr = 0;
	uint8_t prio = CSV_COMMAND_PRIORITY_LOW;
	uint32_t index = 0;
	uint32_t vid = 0;

	if (vcmd->cmd_id != TKM_PSP_CMDID_OFFSET) {
		pr_err("[%s]: unsupported cmd id %x\n", __func__, vcmd->cmd_id);
		return -EINVAL;
	}
	vpsp_get_context(&vpsp_ctx, vpsp->kvm->userspace_pid);

	// only tkm cmd need vid
	if (cmd_type_is_tkm(vcmd->cmd_id)) {
		// check the permission to use the default vid when no vid is set
		ret = vpsp_get_vid(&vid, vpsp->kvm->userspace_pid);
		if (ret && !vpsp_get_default_vid_permission()) {
			pr_err("[%s]: not allowed tkm command without vid\n", __func__);
	ret = check_cmd_copy_forward_op_permission(vpsp, vpsp_ctx, data_gpa, cmd);
	if (unlikely(ret)) {
		pr_err("copy operation not allowed\n");
		return -EPERM;
	}
	}

	if (vpsp_ctx)
		vid = vpsp_ctx->vid;

	if (unlikely(vpsp->read_guest(vpsp->kvm, psp_ret_gpa, &psp_ret,
					sizeof(psp_ret))))
@@ -172,9 +424,9 @@ int kvm_pv_psp_copy_forward_op(struct kvm_vpsp *vpsp, int cmd, gpa_t data_gpa, g
			goto end;
		}

		data_paddr = PUT_PSP_VID(__psp_pa(hbuf.data), vid);
		/* try to send command to the device for execution*/
		ret = vpsp_try_do_cmd(vid, cmd, (void *)hbuf.data,
				(struct vpsp_ret *)&psp_ret);
		ret = vpsp_try_do_cmd(cmd, data_paddr, (struct vpsp_ret *)&psp_ret);
		if (unlikely(ret)) {
			pr_err("[%s]: vpsp_try_do_cmd failed\n", __func__);
			ret = -EFAULT;
@@ -202,8 +454,9 @@ int kvm_pv_psp_copy_forward_op(struct kvm_vpsp *vpsp, int cmd, gpa_t data_gpa, g
		prio = vcmd->is_high_rb ? CSV_COMMAND_PRIORITY_HIGH :
			CSV_COMMAND_PRIORITY_LOW;
		index = psp_ret.index;
		data_paddr = PUT_PSP_VID(__psp_pa(g_hbuf_wrap[prio][index].data), vid);
		/* try to get the execution result from ringbuffer*/
		ret = vpsp_try_get_result(vid, prio, index, g_hbuf_wrap[prio][index].data,
		ret = vpsp_try_get_result(prio, index, data_paddr,
					(struct vpsp_ret *)&psp_ret);
		if (unlikely(ret)) {
			pr_err("[%s]: vpsp_try_get_result failed\n", __func__);
Loading