Commit fd94fcf0 authored by Nathan Huckleberry, committed by Herbert Xu

crypto: x86/aesni-xctr - Add accelerated implementation of XCTR

Add a hardware-accelerated version of XCTR for x86-64 CPUs with AESNI
support.

More information on XCTR can be found in the HCTR2 paper,
"Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdf
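
For reference (not part of this patch), a minimal single-block sketch of XCTR
keystream generation in kernel-style C. It mirrors the generic fallback used
for the final partial block in the glue code below, reusing the aesni_enc()
and crypto_xor_cpy() helpers that code calls; the helper name is hypothetical,
and the real AVX routines process eight blocks per iteration. Unlike CTR, the
block index (starting at 1) is XORed into the IV as a little-endian value
rather than added with a carry chain:

/* hypothetical helper, illustration only; len <= AES_BLOCK_SIZE */
static void xctr_one_block(struct crypto_aes_ctx *ctx, u8 *dst, const u8 *src,
			   unsigned int len, const u8 *iv, u32 blk_num)
{
	__le32 ctrblk[AES_BLOCK_SIZE / sizeof(__le32)];
	u8 keystream[AES_BLOCK_SIZE];

	memcpy(ctrblk, iv, AES_BLOCK_SIZE);
	/* XORing the low 32 bits is enough while the block index fits in
	 * 32 bits, exactly as the partial-block fallback below does */
	ctrblk[0] ^= cpu_to_le32(blk_num);
	aesni_enc(ctx, keystream, (u8 *)ctrblk);
	crypto_xor_cpy(dst, src, keystream, len);	/* dst = src ^ keystream */
}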



Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Ard Biesheuvel <ardb@kernel.org>
Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent 7ff554ce
arch/x86/crypto/aes_ctrby8_avx-x86_64.S  +152 −80
@@ -23,6 +23,11 @@

#define VMOVDQ		vmovdqu

/*
 * Note: the "x" prefix in these aliases means "this is an xmm register".  The
 * alias prefixes have no relation to XCTR where the "X" prefix means "XOR
 * counter".
 */
#define xdata0		%xmm0
#define xdata1		%xmm1
#define xdata2		%xmm2
@@ -31,8 +36,10 @@
#define xdata5		%xmm5
#define xdata6		%xmm6
#define xdata7		%xmm7
#define xcounter	%xmm8
#define xbyteswap	%xmm9
#define xcounter	%xmm8	// CTR mode only
#define xiv		%xmm8	// XCTR mode only
#define xbyteswap	%xmm9	// CTR mode only
#define xtmp		%xmm9	// XCTR mode only
#define xkey0		%xmm10
#define xkey4		%xmm11
#define xkey8		%xmm12
@@ -45,7 +52,7 @@
#define p_keys		%rdx
#define p_out		%rcx
#define num_bytes	%r8

#define counter		%r9	// XCTR mode only
#define tmp		%r10
#define	DDQ_DATA	0
#define	XDATA		1
@@ -102,7 +109,7 @@ ddq_add_8:
 * do_aes num_in_par load_keys key_len
 * This increments p_in, but not p_out
 */
.macro do_aes b, k, key_len
.macro do_aes b, k, key_len, xctr
	.set by, \b
	.set load_keys, \k
	.set klen, \key_len
@@ -111,8 +118,22 @@ ddq_add_8:
		vmovdqa	0*16(p_keys), xkey0
	.endif

	.if \xctr
		movq counter, xtmp
		.set i, 0
		.rept (by)
			club XDATA, i
			vpaddq	(ddq_add_1 + 16 * i)(%rip), xtmp, var_xdata
			.set i, (i +1)
		.endr
		.set i, 0
		.rept (by)
			club	XDATA, i
			vpxor	xiv, var_xdata, var_xdata
			.set i, (i +1)
		.endr
	.else
		vpshufb	xbyteswap, xcounter, xdata0

		.set i, 1
		.rept (by - 1)
			club XDATA, i
@@ -125,15 +146,20 @@ ddq_add_8:
			vpshufb	xbyteswap, var_xdata, var_xdata
			.set i, (i +1)
		.endr
	.endif

	vmovdqa	1*16(p_keys), xkeyA

	vpxor	xkey0, xdata0, xdata0
	.if \xctr
		add $by, counter
	.else
		vpaddq	(ddq_add_1 + 16 * (by - 1))(%rip), xcounter, xcounter
		vptest	ddq_low_msk(%rip), xcounter
		jnz	1f
		vpaddq	ddq_high_add_1(%rip), xcounter, xcounter
		1:
	.endif

	.set i, 1
	.rept (by - 1)
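
For reference (not part of this patch), the XCTR branch added above builds each
AES input block as IV XOR little-endian(block number), keeping the running block
number in a general-purpose register instead of an xmm counter that needs
byte-swapping. A rough scalar C equivalent of what the vpaddq/vpxor pair
computes for block i of a group (names hypothetical, little-endian host assumed;
`counter` holds byte_ctr / 16, see the `shr $4` in do_aes_ctrmain further down):

	u64 blk_num = counter + i + 1;	/* ddq_add_1 + 16*i supplies the (i + 1) */
	u8 ctrblk[16] = { 0 };
	int j;

	memcpy(ctrblk, &blk_num, sizeof(blk_num));	/* low 64 bits, little-endian */
	for (j = 0; j < 16; j++)
		ctrblk[j] ^= iv[j];			/* vpxor xiv, var_xdata */
	/* ctrblk is then run through the AES rounds to give keystream block i */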
@@ -371,94 +397,99 @@ ddq_add_8:
	.endr
.endm

.macro do_aes_load val, key_len
	do_aes \val, 1, \key_len
.macro do_aes_load val, key_len, xctr
	do_aes \val, 1, \key_len, \xctr
.endm

.macro do_aes_noload val, key_len
	do_aes \val, 0, \key_len
.macro do_aes_noload val, key_len, xctr
	do_aes \val, 0, \key_len, \xctr
.endm

/* main body of aes ctr load */

.macro do_aes_ctrmain key_len
.macro do_aes_ctrmain key_len, xctr
	cmp	$16, num_bytes
	jb	.Ldo_return2\key_len
	jb	.Ldo_return2\xctr\key_len

	.if \xctr
		shr	$4, counter
		vmovdqu	(p_iv), xiv
	.else
		vmovdqa	byteswap_const(%rip), xbyteswap
		vmovdqu	(p_iv), xcounter
		vpshufb	xbyteswap, xcounter, xcounter
	.endif

	mov	num_bytes, tmp
	and	$(7*16), tmp
	jz	.Lmult_of_8_blks\key_len
	jz	.Lmult_of_8_blks\xctr\key_len

	/* 1 <= tmp <= 7 */
	cmp	$(4*16), tmp
	jg	.Lgt4\key_len
	je	.Leq4\key_len
	jg	.Lgt4\xctr\key_len
	je	.Leq4\xctr\key_len

.Llt4\key_len:
.Llt4\xctr\key_len:
	cmp	$(2*16), tmp
	jg	.Leq3\key_len
	je	.Leq2\key_len
	jg	.Leq3\xctr\key_len
	je	.Leq2\xctr\key_len

.Leq1\key_len:
	do_aes_load	1, \key_len
.Leq1\xctr\key_len:
	do_aes_load	1, \key_len, \xctr
	add	$(1*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq2\key_len:
	do_aes_load	2, \key_len
.Leq2\xctr\key_len:
	do_aes_load	2, \key_len, \xctr
	add	$(2*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len


.Leq3\key_len:
	do_aes_load	3, \key_len
.Leq3\xctr\key_len:
	do_aes_load	3, \key_len, \xctr
	add	$(3*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq4\key_len:
	do_aes_load	4, \key_len
.Leq4\xctr\key_len:
	do_aes_load	4, \key_len, \xctr
	add	$(4*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Lgt4\key_len:
.Lgt4\xctr\key_len:
	cmp	$(6*16), tmp
	jg	.Leq7\key_len
	je	.Leq6\key_len
	jg	.Leq7\xctr\key_len
	je	.Leq6\xctr\key_len

.Leq5\key_len:
	do_aes_load	5, \key_len
.Leq5\xctr\key_len:
	do_aes_load	5, \key_len, \xctr
	add	$(5*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq6\key_len:
	do_aes_load	6, \key_len
.Leq6\xctr\key_len:
	do_aes_load	6, \key_len, \xctr
	add	$(6*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq7\key_len:
	do_aes_load	7, \key_len
.Leq7\xctr\key_len:
	do_aes_load	7, \key_len, \xctr
	add	$(7*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Lmult_of_8_blks\key_len:
.Lmult_of_8_blks\xctr\key_len:
	.if (\key_len != KEY_128)
		vmovdqa	0*16(p_keys), xkey0
		vmovdqa	4*16(p_keys), xkey4
@@ -471,17 +502,19 @@ ddq_add_8:
		vmovdqa	9*16(p_keys), xkey12
	.endif
.align 16
.Lmain_loop2\key_len:
.Lmain_loop2\xctr\key_len:
	/* num_bytes is a multiple of 8 and >0 */
	do_aes_noload	8, \key_len
	do_aes_noload	8, \key_len, \xctr
	add	$(8*16), p_out
	sub	$(8*16), num_bytes
	jne	.Lmain_loop2\key_len
	jne	.Lmain_loop2\xctr\key_len

.Ldo_return2\key_len:
.Ldo_return2\xctr\key_len:
	.if !\xctr
		/* return updated IV */
		vpshufb	xbyteswap, xcounter, xcounter
		vmovdqu	xcounter, (p_iv)
	.endif
	RET
.endm
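
For reference (not part of this patch): folding the new \xctr argument into
every local label above is what allows the same macros to be expanded twice in
this file, once for CTR (xctr=0) and once for XCTR (xctr=1); without the extra
suffix the two expansions would presumably collide on labels such as
.Lmain_loop2\key_len.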

@@ -494,7 +527,7 @@ ddq_add_8:
 */
SYM_FUNC_START(aes_ctr_enc_128_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_128
	do_aes_ctrmain KEY_128 0

SYM_FUNC_END(aes_ctr_enc_128_avx_by8)

@@ -507,7 +540,7 @@ SYM_FUNC_END(aes_ctr_enc_128_avx_by8)
 */
SYM_FUNC_START(aes_ctr_enc_192_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_192
	do_aes_ctrmain KEY_192 0

SYM_FUNC_END(aes_ctr_enc_192_avx_by8)

@@ -520,6 +553,45 @@ SYM_FUNC_END(aes_ctr_enc_192_avx_by8)
 */
SYM_FUNC_START(aes_ctr_enc_256_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_256
	do_aes_ctrmain KEY_256 0

SYM_FUNC_END(aes_ctr_enc_256_avx_by8)

/*
 * routine to do AES128 XCTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv, const void *keys,
 * 	u8* out, unsigned int num_bytes, unsigned int byte_ctr)
 */
SYM_FUNC_START(aes_xctr_enc_128_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_128 1

SYM_FUNC_END(aes_xctr_enc_128_avx_by8)

/*
 * routine to do AES192 XCTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv, const void *keys,
 * 	u8* out, unsigned int num_bytes, unsigned int byte_ctr)
 */
SYM_FUNC_START(aes_xctr_enc_192_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_192 1

SYM_FUNC_END(aes_xctr_enc_192_avx_by8)

/*
 * routine to do AES256 XCTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv, const void *keys,
 * 	u8* out, unsigned int num_bytes, unsigned int byte_ctr)
 */
SYM_FUNC_START(aes_xctr_enc_256_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_256 1

SYM_FUNC_END(aes_xctr_enc_256_avx_by8)
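
For reference (not part of this patch), a usage sketch of the new entry points;
it mirrors the aesni_xctr_enc_avx_tfm() wrapper added in the glue file below.
num_bytes is expected to be a whole number of blocks (the glue only passes
nbytes & AES_BLOCK_MASK), and byte_ctr is the number of keystream bytes already
produced, which the routine converts to a block index via the shr $4 above, so
it is likewise a block multiple in practice:

	kernel_fpu_begin();
	/* process len whole blocks starting at keystream offset byte_ctr;
	 * encryption and decryption are the same operation in XCTR */
	aes_xctr_enc_256_avx_by8(src, iv, aes_ctx, dst, len, byte_ctr);
	kernel_fpu_end();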
arch/x86/crypto/aesni-intel_glue.c  +113 −1
@@ -135,6 +135,20 @@ asmlinkage void aes_ctr_enc_192_avx_by8(const u8 *in, u8 *iv,
		void *keys, u8 *out, unsigned int num_bytes);
asmlinkage void aes_ctr_enc_256_avx_by8(const u8 *in, u8 *iv,
		void *keys, u8 *out, unsigned int num_bytes);


asmlinkage void aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv,
	const void *keys, u8 *out, unsigned int num_bytes,
	unsigned int byte_ctr);

asmlinkage void aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv,
	const void *keys, u8 *out, unsigned int num_bytes,
	unsigned int byte_ctr);

asmlinkage void aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv,
	const void *keys, u8 *out, unsigned int num_bytes,
	unsigned int byte_ctr);

/*
 * asmlinkage void aesni_gcm_init_avx_gen2()
 * gcm_data *my_ctx_data, context data
@@ -527,6 +541,59 @@ static int ctr_crypt(struct skcipher_request *req)
	return err;
}

static void aesni_xctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
				   const u8 *in, unsigned int len, u8 *iv,
				   unsigned int byte_ctr)
{
	if (ctx->key_length == AES_KEYSIZE_128)
		aes_xctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len,
					 byte_ctr);
	else if (ctx->key_length == AES_KEYSIZE_192)
		aes_xctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len,
					 byte_ctr);
	else
		aes_xctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len,
					 byte_ctr);
}

static int xctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
	u8 keystream[AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	unsigned int nbytes;
	unsigned int byte_ctr = 0;
	int err;
	__le32 block[AES_BLOCK_SIZE / sizeof(__le32)];

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		kernel_fpu_begin();
		if (nbytes & AES_BLOCK_MASK)
			aesni_xctr_enc_avx_tfm(ctx, walk.dst.virt.addr,
				walk.src.virt.addr, nbytes & AES_BLOCK_MASK,
				walk.iv, byte_ctr);
		nbytes &= ~AES_BLOCK_MASK;
		byte_ctr += walk.nbytes - nbytes;

		if (walk.nbytes == walk.total && nbytes > 0) {
			memcpy(block, walk.iv, AES_BLOCK_SIZE);
			block[0] ^= cpu_to_le32(1 + byte_ctr / AES_BLOCK_SIZE);
			aesni_enc(ctx, keystream, (u8 *)block);
			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes -
				       nbytes, walk.src.virt.addr + walk.nbytes
				       - nbytes, keystream, nbytes);
			byte_ctr += nbytes;
			nbytes = 0;
		}
		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}
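
For reference (not part of this patch), a worked trace of the tail handling
above with an assumed request size: a 50-byte request completed in one walk
step has walk.nbytes = 50, so the AVX routine handles 50 & AES_BLOCK_MASK = 48
bytes (keystream blocks 1-3) and byte_ctr becomes 48; the remaining 2 bytes
then take the one-block fallback, which XORs them with the start of
aesni_enc(ctx, ..., IV ^ le32(1 + 48/16)), i.e. keystream block 4.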

static int
rfc4106_set_hash_subkey(u8 *hash_subkey, const u8 *key, unsigned int key_len)
{
@@ -1050,6 +1117,33 @@ static struct skcipher_alg aesni_skciphers[] = {
static
struct simd_skcipher_alg *aesni_simd_skciphers[ARRAY_SIZE(aesni_skciphers)];

#ifdef CONFIG_X86_64
/*
 * XCTR does not have a non-AVX implementation, so it must be enabled
 * conditionally.
 */
static struct skcipher_alg aesni_xctr = {
	.base = {
		.cra_name		= "__xctr(aes)",
		.cra_driver_name	= "__xctr-aes-aesni",
		.cra_priority		= 400,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= 1,
		.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= aesni_skcipher_setkey,
	.encrypt	= xctr_crypt,
	.decrypt	= xctr_crypt,
};

static struct simd_skcipher_alg *aesni_simd_xctr;
#endif /* CONFIG_X86_64 */
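
For reference (not part of this patch), a sketch of how a kernel user such as
the HCTR2 template reaches this implementation once simd_register_skciphers_compat()
wraps it below: the outer algorithm is requested as "xctr(aes)" (the internal
"__xctr(aes)" is not allocated directly). Variable names and the synchronous
wait pattern are illustrative only:

	struct crypto_skcipher *tfm;
	struct skcipher_request *req;
	struct scatterlist sg;
	DECLARE_CRYPTO_WAIT(wait);
	u8 iv[AES_BLOCK_SIZE];	/* per-message nonce */
	int err;

	tfm = crypto_alloc_skcipher("xctr(aes)", 0, 0);
	if (IS_ERR(tfm))
		return PTR_ERR(tfm);
	err = crypto_skcipher_setkey(tfm, key, AES_KEYSIZE_256);
	if (err)
		goto out_free_tfm;

	req = skcipher_request_alloc(tfm, GFP_KERNEL);
	if (!req) {
		err = -ENOMEM;
		goto out_free_tfm;
	}
	sg_init_one(&sg, buf, buf_len);		/* in-place transform */
	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				      crypto_req_done, &wait);
	skcipher_request_set_crypt(req, &sg, &sg, buf_len, iv);
	err = crypto_wait_req(crypto_skcipher_encrypt(req), &wait);
	skcipher_request_free(req);
out_free_tfm:
	crypto_free_skcipher(tfm);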

#ifdef CONFIG_X86_64
static int generic_gcmaes_set_key(struct crypto_aead *aead, const u8 *key,
				  unsigned int key_len)
@@ -1163,7 +1257,7 @@ static int __init aesni_init(void)
		static_call_update(aesni_ctr_enc_tfm, aesni_ctr_enc_avx_tfm);
		pr_info("AES CTR mode by8 optimization enabled\n");
	}
#endif
#endif /* CONFIG_X86_64 */

	err = crypto_register_alg(&aesni_cipher_alg);
	if (err)
@@ -1180,8 +1274,22 @@ static int __init aesni_init(void)
	if (err)
		goto unregister_skciphers;

#ifdef CONFIG_X86_64
	if (boot_cpu_has(X86_FEATURE_AVX))
		err = simd_register_skciphers_compat(&aesni_xctr, 1,
						     &aesni_simd_xctr);
	if (err)
		goto unregister_aeads;
#endif /* CONFIG_X86_64 */

	return 0;

#ifdef CONFIG_X86_64
unregister_aeads:
	simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads),
				aesni_simd_aeads);
#endif /* CONFIG_X86_64 */

unregister_skciphers:
	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
				  aesni_simd_skciphers);
@@ -1197,6 +1305,10 @@ static void __exit aesni_exit(void)
	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
				  aesni_simd_skciphers);
	crypto_unregister_alg(&aesni_cipher_alg);
#ifdef CONFIG_X86_64
	if (boot_cpu_has(X86_FEATURE_AVX))
		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
#endif /* CONFIG_X86_64 */
}

late_initcall(aesni_init);
crypto/Kconfig  +1 −1
@@ -1169,7 +1169,7 @@ config CRYPTO_AES_NI_INTEL
	  In addition to AES cipher algorithm support, the acceleration
	  for some popular block cipher mode is supported too, including
	  ECB, CBC, LRW, XTS. The 64 bit version has additional
	  acceleration for CTR.
	  acceleration for CTR and XCTR.

config CRYPTO_AES_SPARC64
	tristate "AES cipher algorithms (SPARC64)"