Unverified Commit 3a2df632 authored by Greentime Hu's avatar Greentime Hu Committed by Palmer Dabbelt
Browse files

riscv: Add task switch support for vector

parent 03c3fcd9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ struct thread_struct {
	unsigned long s[12];	/* s[0]: frame pointer */
	struct __riscv_d_ext_state fstate;
	unsigned long bad_cause;
	struct __riscv_v_ext_state vstate;
};

/* Whitelist the fstate from the task_struct for hardened usercopy */
+3 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@

#include <linux/jump_label.h>
#include <linux/sched/task_stack.h>
#include <asm/vector.h>
#include <asm/hwcap.h>
#include <asm/processor.h>
#include <asm/ptrace.h>
@@ -78,6 +79,8 @@ do { \
	struct task_struct *__next = (next);		\
	if (has_fpu())					\
		__switch_to_fpu(__prev, __next);	\
	if (has_vector())					\
		__switch_to_vector(__prev, __next);	\
	((last) = __switch_to(__prev, __next));		\
} while (0)

+3 −0
Original line number Diff line number Diff line
@@ -81,6 +81,9 @@ struct thread_info {
	.preempt_count	= INIT_PREEMPT_COUNT,	\
}

void arch_release_task_struct(struct task_struct *tsk);
int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src);

#endif /* !__ASSEMBLY__ */

/*
+38 −0
Original line number Diff line number Diff line
@@ -12,6 +12,9 @@
#ifdef CONFIG_RISCV_ISA_V

#include <linux/stringify.h>
#include <linux/sched.h>
#include <linux/sched/task_stack.h>
#include <asm/ptrace.h>
#include <asm/hwcap.h>
#include <asm/csr.h>
#include <asm/asm.h>
@@ -124,6 +127,38 @@ static inline void __riscv_v_vstate_restore(struct __riscv_v_ext_state *restore_
	riscv_v_disable();
}

static inline void riscv_v_vstate_save(struct task_struct *task,
				       struct pt_regs *regs)
{
	if ((regs->status & SR_VS) == SR_VS_DIRTY) {
		struct __riscv_v_ext_state *vstate = &task->thread.vstate;

		__riscv_v_vstate_save(vstate, vstate->datap);
		__riscv_v_vstate_clean(regs);
	}
}

static inline void riscv_v_vstate_restore(struct task_struct *task,
					  struct pt_regs *regs)
{
	if ((regs->status & SR_VS) != SR_VS_OFF) {
		struct __riscv_v_ext_state *vstate = &task->thread.vstate;

		__riscv_v_vstate_restore(vstate, vstate->datap);
		__riscv_v_vstate_clean(regs);
	}
}

static inline void __switch_to_vector(struct task_struct *prev,
				      struct task_struct *next)
{
	struct pt_regs *regs;

	regs = task_pt_regs(prev);
	riscv_v_vstate_save(prev, regs);
	riscv_v_vstate_restore(next, task_pt_regs(next));
}

#else /* ! CONFIG_RISCV_ISA_V  */

struct pt_regs;
@@ -132,6 +167,9 @@ static inline int riscv_v_setup_vsize(void) { return -EOPNOTSUPP; }
static __always_inline bool has_vector(void) { return false; }
static inline bool riscv_v_vstate_query(struct pt_regs *regs) { return false; }
#define riscv_v_vsize (0)
#define riscv_v_vstate_save(task, regs)		do {} while (0)
#define riscv_v_vstate_restore(task, regs)	do {} while (0)
#define __switch_to_vector(__prev, __next)	do {} while (0)
#define riscv_v_vstate_off(regs)		do {} while (0)
#define riscv_v_vstate_on(regs)			do {} while (0)

+19 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@
#include <asm/switch_to.h>
#include <asm/thread_info.h>
#include <asm/cpuidle.h>
#include <asm/vector.h>

register unsigned long gp_in_global __asm__("gp");

@@ -146,12 +147,28 @@ void flush_thread(void)
	fstate_off(current, task_pt_regs(current));
	memset(&current->thread.fstate, 0, sizeof(current->thread.fstate));
#endif
#ifdef CONFIG_RISCV_ISA_V
	/* Reset vector state */
	riscv_v_vstate_off(task_pt_regs(current));
	kfree(current->thread.vstate.datap);
	memset(&current->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
#endif
}

void arch_release_task_struct(struct task_struct *tsk)
{
	/* Free the vector context of datap. */
	if (has_vector())
		kfree(tsk->thread.vstate.datap);
}

int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src)
{
	fstate_save(src, task_pt_regs(src));
	*dst = *src;
	/* clear entire V context, including datap for a new task */
	memset(&dst->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));

	return 0;
}

@@ -176,6 +193,8 @@ int copy_thread(struct task_struct *p, const struct kernel_clone_args *args)
		p->thread.s[1] = (unsigned long)args->fn_arg;
	} else {
		*childregs = *(current_pt_regs());
		/* Turn off status.VS */
		riscv_v_vstate_off(childregs);
		if (usp) /* User fork */
			childregs->sp = usp;
		if (clone_flags & CLONE_SETTLS)