Commit 23a251cc authored by Nathan Huckleberry's avatar Nathan Huckleberry Committed by Herbert Xu
Browse files

crypto: arm64/aes-xctr - Add accelerated implementation of XCTR

Add hardware accelerated version of XCTR for ARM64 CPUs with ARMv8
Crypto Extension support.  This XCTR implementation is based on the CTR
implementation in aes-modes.S.

More information on XCTR can be found in
the HCTR2 paper: "Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdf



Signed-off-by: default avatarNathan Huckleberry <nhuck@google.com>
Reviewed-by: default avatarArd Biesheuvel <ardb@kernel.org>
Reviewed-by: default avatarEric Biggers <ebiggers@google.com>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent fd94fcf0
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -96,13 +96,13 @@ config CRYPTO_AES_ARM64_CE_CCM
	select CRYPTO_LIB_AES

config CRYPTO_AES_ARM64_CE_BLK
	tristate "AES in ECB/CBC/CTR/XTS modes using ARMv8 Crypto Extensions"
	tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using ARMv8 Crypto Extensions"
	depends on KERNEL_MODE_NEON
	select CRYPTO_SKCIPHER
	select CRYPTO_AES_ARM64_CE

config CRYPTO_AES_ARM64_NEON_BLK
	tristate "AES in ECB/CBC/CTR/XTS modes using NEON instructions"
	tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using NEON instructions"
	depends on KERNEL_MODE_NEON
	select CRYPTO_SKCIPHER
	select CRYPTO_LIB_AES
+62 −2
Original line number Diff line number Diff line
@@ -34,10 +34,11 @@
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xctr_encrypt	ce_aes_xctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
#else
#define MODE			"neon"
#define PRIO			200
@@ -50,16 +51,18 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xctr_encrypt	neon_aes_xctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
#endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("xctr(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
@@ -89,6 +92,9 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 ctr[]);

asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				 int rounds, int bytes, u8 ctr[], int byte_ctr);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
@@ -442,6 +448,44 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
	return err ?: cbc_decrypt_walk(req, &walk);
}

static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int byte_ctr = 0;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		unsigned int nbytes = walk.nbytes;
		u8 *dst = walk.dst.virt.addr;
		u8 buf[AES_BLOCK_SIZE];

		if (unlikely(nbytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - nbytes,
					   src, nbytes);
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
						 walk.iv, byte_ctr);
		kernel_neon_end();

		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - nbytes, nbytes);
		byte_ctr += nbytes;

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	return err;
}

static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -669,6 +713,22 @@ static struct skcipher_alg aes_algs[] = { {
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt,
	.decrypt	= ctr_encrypt,
}, {
	.base = {
		.cra_name		= "xctr(aes)",
		.cra_driver_name	= "xctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= xctr_encrypt,
	.decrypt	= xctr_encrypt,
}, {
	.base = {
		.cra_name		= "xts(aes)",
+104 −62
Original line number Diff line number Diff line
@@ -318,38 +318,60 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.previous


	/*
	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int bytes, u8 ctr[])
	 * This macro generates the code for CTR and XCTR mode.
	 */

AES_FUNC_START(aes_ctr_encrypt)
.macro ctr_encrypt xctr
	stp		x29, x30, [sp, #-16]!
	mov		x29, sp

	enc_prepare	w3, x2, x12
	ld1		{vctr.16b}, [x5]

	.if \xctr
		umov		x12, vctr.d[0]
		lsr		w11, w6, #4
	.else
		umov		x12, vctr.d[1] /* keep swabbed ctr in reg */
		rev		x12, x12
	.endif

.LctrloopNx:
.LctrloopNx\xctr:
	add		w7, w4, #15
	sub		w4, w4, #MAX_STRIDE << 4
	lsr		w7, w7, #4
	mov		w8, #MAX_STRIDE
	cmp		w7, w8
	csel		w7, w7, w8, lt
	adds		x12, x12, x7

	.if \xctr
		add		x11, x11, x7
	.else
		adds		x12, x12, x7
	.endif
	mov		v0.16b, vctr.16b
	mov		v1.16b, vctr.16b
	mov		v2.16b, vctr.16b
	mov		v3.16b, vctr.16b
ST5(	mov		v4.16b, vctr.16b		)
	.if \xctr
		sub		x6, x11, #MAX_STRIDE - 1
		sub		x7, x11, #MAX_STRIDE - 2
		sub		x8, x11, #MAX_STRIDE - 3
		sub		x9, x11, #MAX_STRIDE - 4
ST5(		sub		x10, x11, #MAX_STRIDE - 5	)
		eor		x6, x6, x12
		eor		x7, x7, x12
		eor		x8, x8, x12
		eor		x9, x9, x12
ST5(		eor		x10, x10, x12			)
		mov		v0.d[0], x6
		mov		v1.d[0], x7
		mov		v2.d[0], x8
		mov		v3.d[0], x9
ST5(		mov		v4.d[0], x10			)
	.else
		bcs		0f

		.subsection	1
		/* apply carry to outgoing counter */
0:		umov		x8, vctr.d[0]
@@ -390,7 +412,8 @@ ST5( sub x10, x12, #MAX_STRIDE - 4 )
ST5(		rev		x10, x10			)
		mov		v3.d[1], x9
ST5(		mov		v4.d[1], x10			)
	tbnz		w4, #31, .Lctrtail
	.endif
	tbnz		w4, #31, .Lctrtail\xctr
	ld1		{v5.16b-v7.16b}, [x1], #48
ST4(	bl		aes_encrypt_block4x		)
ST5(	bl		aes_encrypt_block5x		)
@@ -403,16 +426,17 @@ ST5( ld1 {v5.16b-v6.16b}, [x1], #32 )
ST5(	eor		v4.16b, v6.16b, v4.16b		)
	st1		{v0.16b-v3.16b}, [x0], #64
ST5(	st1		{v4.16b}, [x0], #16		)
	cbz		w4, .Lctrout
	b		.LctrloopNx
	cbz		w4, .Lctrout\xctr
	b		.LctrloopNx\xctr

.Lctrout:
.Lctrout\xctr:
	.if !\xctr
		st1		{vctr.16b}, [x5] /* return next CTR value */
	.endif
	ldp		x29, x30, [sp], #16
	ret

.Lctrtail:
	/* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */
.Lctrtail\xctr:
	mov		x16, #16
	ands		x6, x4, #0xf
	csel		x13, x6, x16, ne
@@ -427,7 +451,7 @@ ST5( csel x14, x16, xzr, gt )

	adr_l		x12, .Lcts_permute_table
	add		x12, x12, x13
	ble		.Lctrtail1x
	ble		.Lctrtail1x\xctr

ST5(	ld1		{v5.16b}, [x1], x14		)
	ld1		{v6.16b}, [x1], x15
@@ -459,9 +483,9 @@ ST5( st1 {v5.16b}, [x0], x14 )
	add		x13, x13, x0
	st1		{v9.16b}, [x13]		// overlapping stores
	st1		{v8.16b}, [x0]
	b		.Lctrout
	b		.Lctrout\xctr

.Lctrtail1x:
.Lctrtail1x\xctr:
	sub		x7, x6, #16
	csel		x6, x6, x7, eq
	add		x1, x1, x6
@@ -476,9 +500,27 @@ ST5( mov v3.16b, v4.16b )
	eor		v5.16b, v5.16b, v3.16b
	bif		v5.16b, v6.16b, v11.16b
	st1		{v5.16b}, [x0]
	b		.Lctrout
	b		.Lctrout\xctr
.endm

	/*
	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int bytes, u8 ctr[])
	 */

AES_FUNC_START(aes_ctr_encrypt)
	ctr_encrypt 0
AES_FUNC_END(aes_ctr_encrypt)

	/*
	 * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int bytes, u8 const iv[], int byte_ctr)
	 */

AES_FUNC_START(aes_xctr_encrypt)
	ctr_encrypt 1
AES_FUNC_END(aes_xctr_encrypt)


	/*
	 * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,