Commit 3b77535a authored by leoliu-oc
Browse files

Add support for Zhaoxin GMI SM4 Block Cipher algorithm

zhaoxin inclusion
category: feature
bugzilla: https://gitee.com/openeuler/kernel/issues/I8KTBP


CVE: NA

-----------------

This driver adds support for the Zhaoxin GMI SM4 instruction, allowing
users to develop applications that offer both high performance and high
security.

Signed-off-by: leoliu-oc <leoliu-oc@zhaoxin.com>
parent 492d0f14
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -231,6 +231,25 @@ config CRYPTO_SM4_AESNI_AVX2_X86_64

	  If unsure, say N.

config CRYPTO_SM4_ZHAOXIN_GMI
	tristate "Ciphers: SM4 with modes: ECB, CBC, CTR, CFB, OFB (Zhaoxin GMI)"
	depends on X86 && CRYPTO
	select CRYPTO_SKCIPHER
	select CRYPTO_SIMD
	select CRYPTO_ALGAPI
	select CRYPTO_SM4
	help
	  SM4 cipher algorithms (Zhaoxin GMI Instruction).

	  SM4 (GB/T 32907-2016) is a cryptographic standard issued by the
	  Organization of State Commercial Administration of China (OSCCA)
	  as an authorized cryptographic algorithm for use within China.

	  This is SM4 optimized implementation using Zhaoxin GMI
	  instruction set for block cipher.

	  If unsure, say N.

config CRYPTO_TWOFISH_586
	tristate "Ciphers: Twofish (32-bit)"
	depends on (X86 || UML_X86) && !64BIT
+1 −0
Original line number Diff line number Diff line
@@ -110,6 +110,7 @@ obj-$(CONFIG_CRYPTO_ARIA_GFNI_AVX512_X86_64) += aria-gfni-avx512-x86_64.o
aria-gfni-avx512-x86_64-y := aria-gfni-avx512-asm_64.o aria_gfni_avx512_glue.o

obj-$(CONFIG_CRYPTO_SM3_ZHAOXIN_GMI) += sm3-zhaoxin-gmi.o
obj-$(CONFIG_CRYPTO_SM4_ZHAOXIN_GMI) += sm4-zhaoxin-gmi.o

quiet_cmd_perlasm = PERLASM $@
      cmd_perlasm = $(PERL) $< > $@
+860 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
/*
 * sm4-zhaoxin-gmi.c - SM4 skcipher glue code for the Zhaoxin GMI instruction.
 *
 * Copyright (C) 2023 Shanghai Zhaoxin Semiconductor LTD.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <linux/types.h>
#include <linux/module.h>
#include <linux/err.h>
#include <crypto/cryptd.h>
#include <crypto/scatterwalk.h>
#include <crypto/algapi.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <linux/workqueue.h>
#include <crypto/sm4.h>
#include <asm/unaligned.h>
#include <linux/processor.h>
#include <linux/cpufeature.h>


/*
 * Mode-select bits. Each is OR-ed, together with 0x20, into cword.pad of
 * struct sm4_cipher_data by the per-mode entry points below.
 */
#define SM4_ECB  (1<<6)
#define SM4_CBC  (1<<7)
#define SM4_CFB  (1<<8)
#define SM4_OFB  (1<<9)
#define SM4_CTR  (1<<10)

/* NOTE(review): not referenced anywhere in this file - confirm it is still needed. */
#define ZX_GMI_ALIGNMENT 16

/* Read a big-endian 16-bit value from the two bytes at p[0] (high) and p[1] (low). */
#define GETU16(p)  ((u16)(p)[0]<<8 | (u16)(p)[1])

/*
 * Per-request cipher parameters.
 *
 * Only cword is read by the code in this file (rep_xcrypt() loads cword.pad
 * into the accumulator register); iv and keys are never accessed here.
 * cword can be addressed either as the raw 32-bit value (pad) or through
 * the bitfield view (b): encdec (1 bit), func (5 bits), mode (5 bits) and
 * digest (1 bit).
 */
struct sm4_cipher_data {
	u8 iv[SM4_BLOCK_SIZE]; /* Initialization vector */
	union {
		u32 pad;
		struct {
			u32 encdec:1;
			u32 func:5;
			u32 mode:5;
			u32 digest:1;
		} b;
	} cword;                    /* Control word */
	struct sm4_ctx  keys;  /* Encryption key */
};


/*
 * rep_xcrypt - issue the GMI SM4 instruction for @count blocks.
 * @input:    source buffer, loaded into rSI
 * @output:   destination buffer, loaded into rDI
 * @key:      key material, loaded into rBX
 * @iv:       initialization vector, loaded into rDX
 * @sm4_data: supplies the control word that is loaded into rAX
 * @count:    number of blocks, loaded into rCX
 *
 * Bit 0 of the value handed to the instruction is the inverse of
 * cword.b.encdec: it is cleared when encdec == 1 and set otherwise.
 *
 * Returns the @iv pointer that was passed in.
 */
static u8 *rep_xcrypt(const u8 *input, u8 *output, void *key, u8 *iv,
							struct sm4_cipher_data *sm4_data, u64 count)
{
	unsigned long rax = sm4_data->cword.pad;

	/* Set the flag for encryption or decryption. */
	if (sm4_data->cword.b.encdec == 1)
		rax &= ~0x01;
	else
		rax |= 0x01;

	/*
	 * The instruction reads @input/@key/@iv and writes @output through
	 * pointers the compiler cannot see, so a "memory" clobber is required
	 * to stop it from caching or reordering accesses to those buffers
	 * across the asm statement. rBP/rBX/rCX/rSI/rDI are saved and restored
	 * by hand around the instruction.
	 *
	 * NOTE(review): rAX and rDX are declared as inputs only - confirm that
	 * the instruction leaves them unmodified.
	 */
	__asm__ __volatile__(
		#ifdef __x86_64__
			"pushq %%rbp\n"
			"pushq %%rbx\n"
			"pushq %%rcx\n"
			"pushq %%rsi\n"
			"pushq %%rdi\n"
		#else
			"pushl %%ebp\n"
			"pushl %%ebx\n"
			"pushl %%ecx\n"
			"pushl %%esi\n"
			"pushl %%edi\n"
		#endif
		".byte 0xf3,0x0f,0xa7,0xf0\n"
		#ifdef __x86_64__
			"popq %%rdi\n"
			"popq %%rsi\n"
			"popq %%rcx\n"
			"popq %%rbx\n"
			"popq %%rbp\n"
		#else
			"popl %%edi\n"
			"popl %%esi\n"
			"popl %%ecx\n"
			"popl %%ebx\n"
			"popl %%ebp\n"
		#endif
		:
		: "S"(input), "D"(output), "a"(rax), "b"(key), "c"((unsigned long)count), "d"(iv)
		: "cc", "memory");
	return iv;
}

/*
 * rep_xcrypt_ctr - CTR helper that never lets the low 16 counter bits wrap
 * inside a single rep_xcrypt() run.
 * @input:    source buffer
 * @output:   destination buffer
 * @key:      key material
 * @iv:       counter block; must not be NULL. Its contents are restored to
 *            the value they had on entry before returning, so the caller is
 *            responsible for advancing the counter.
 * @sm4_data: control word passed through to rep_xcrypt()
 * @count:    number of blocks to process
 *
 * The request is cut at every point where the big-endian 16-bit value in
 * iv[14..15] would wrap; between runs the counter is advanced in software
 * with crypto_inc() so that the carry reaches the upper bytes.
 * (Presumably the instruction only increments the low 16 bits - confirm
 * against the GMI documentation.)
 *
 * Returns @iv.
 */
static u8 *rep_xcrypt_ctr(const u8 *input, u8 *output, void *key, u8 *iv,
	struct sm4_cipher_data *sm4_data, u64 count)
{
	u8 oiv[SM4_BLOCK_SIZE];
	u8 ctr[SM4_BLOCK_SIZE];
	u64 avail, run, i;

	/* oiv is handed back untouched; ctr is the running software counter. */
	memcpy(oiv, iv, SM4_BLOCK_SIZE);
	memcpy(ctr, iv, SM4_BLOCK_SIZE);

	while (count) {
		/*
		 * Blocks left before the low 16 bits wrap. This ranges from 1
		 * to 0x10000 and therefore must not be stored in a u16.
		 */
		avail = 0x10000 - GETU16(&ctr[14]);
		run = count < avail ? count : avail;

		memcpy(iv, ctr, SM4_BLOCK_SIZE);
		rep_xcrypt(input, output, key, iv, sm4_data, run);

		count -= run;
		if (!count)
			break;

		/* More data follows: step the counter past this run. */
		for (i = 0; i < run; i++)
			crypto_inc(ctr, SM4_BLOCK_SIZE);

		input += run * SM4_BLOCK_SIZE;
		output += run * SM4_BLOCK_SIZE;
	}

	/* Give the caller back the counter it passed in. */
	memcpy(iv, oiv, SM4_BLOCK_SIZE);

	return iv;
}

/*
 * rep_xcrypt_ecb_ONE - encrypt exactly one block in ECB mode.
 *
 * A private control word (0x20 | SM4_ECB with encdec set) is built locally.
 * The @sm4_data and @count arguments are accepted but ignored: a block
 * count of 1 is always handed to rep_xcrypt().
 *
 * Returns whatever rep_xcrypt() returns.
 */
static u8 *rep_xcrypt_ecb_ONE(const u8 *input, u8 *output, void *key,
						u8 *iv, struct sm4_cipher_data *sm4_data, u64 count)
{
	struct sm4_cipher_data ecb_cw;

	ecb_cw.cword.pad = 0x20 | SM4_ECB;
	ecb_cw.cword.b.encdec = 1;

	return rep_xcrypt(input, output, key, iv, &ecb_cw, 1);
}

/**
 * gmi_sm4_set_key - Store the SM4 key in the transform context.
 * @tfm:     The %crypto_skcipher whose context receives the key.
 * @in_key:  The raw key bytes.
 * @key_len: Length of @in_key in bytes; must equal SM4_KEY_SIZE.
 *
 * The key bytes are copied unmodified to the start of both rkey_enc and
 * rkey_dec of the context.
 *
 * Return: 0 on success, -EINVAL if @key_len is not SM4_KEY_SIZE.
 */
int gmi_sm4_set_key(struct crypto_skcipher  *tfm, const u8 *in_key,
					unsigned int key_len)
{
	struct sm4_ctx *ctx;

	if (key_len == SM4_KEY_SIZE) {
		ctx = crypto_skcipher_ctx(tfm);

		memcpy(ctx->rkey_enc, in_key, key_len);
		memcpy(ctx->rkey_dec, in_key, key_len);
		return 0;
	}

	pr_warn("The key_len must be 16 bytes. please check\n");
	return -EINVAL;
}
EXPORT_SYMBOL_GPL(gmi_sm4_set_key);


/*
 * sm4_cipher_common - walk a request and hand every run of whole blocks to
 * rep_xcrypt() together with the caller-supplied control word @cw.
 *
 * Returns the result of the last skcipher_walk_done() call, or the result
 * of skcipher_walk_virt() if no block was processed.
 */
static int sm4_cipher_common(struct skcipher_request *req, struct sm4_cipher_data *cw)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(crypto_skcipher_reqtfm(req));
	struct skcipher_walk walk;
	unsigned int nblocks;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	for (;;) {
		nblocks = walk.nbytes / SM4_BLOCK_SIZE;
		if (!nblocks)
			break;

		rep_xcrypt(walk.src.virt.addr, walk.dst.virt.addr, ctx->rkey_enc,
			   walk.iv, cw, nblocks);

		/* Anything short of a full block is handed back to the walker. */
		err = skcipher_walk_done(&walk, walk.nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}


/* ECB encryption: control word 0x20 | SM4_ECB with the encdec bit set. */
static int ecb_encrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_ECB;
	cw.cword.b.encdec = 1;

	return sm4_cipher_common(req, &cw);
}

/* ECB decryption: control word 0x20 | SM4_ECB with the encdec bit clear. */
static int ecb_decrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_ECB;

	return sm4_cipher_common(req, &cw);
}

/* CBC encryption: control word 0x20 | SM4_CBC with the encdec bit set. */
static int cbc_encrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CBC;
	cw.cword.b.encdec = 1;

	return sm4_cipher_common(req, &cw);
}

/* CBC decryption: control word 0x20 | SM4_CBC with the encdec bit clear. */
static int cbc_decrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CBC;

	return sm4_cipher_common(req, &cw);
}


/*
 * sm4_cipher_ctr is used for ZX-E and newer.
 *
 * For every walk step all whole blocks are processed by one
 * rep_xcrypt_ctr() call, after which walk.iv is advanced by that many
 * blocks. A trailing partial block is only handled on the final step
 * (walk.nbytes == walk.total): the counter is encrypted into a keystream
 * block and XOR-ed with the remaining bytes.
 */
static int sm4_cipher_ctr(struct skcipher_request *req, struct sm4_cipher_data *cw)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(crypto_skcipher_reqtfm(req));
	struct skcipher_walk walk;
	u8 keystream[SM4_BLOCK_SIZE];
	unsigned int left, nblocks;
	u8 *in, *out;
	u32 k;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	for (left = walk.nbytes; left > 0; left = walk.nbytes) {
		in = walk.src.virt.addr;
		out = walk.dst.virt.addr;

		nblocks = left / SM4_BLOCK_SIZE;
		if (nblocks) {
			rep_xcrypt_ctr(in, out, ctx->rkey_enc, walk.iv, cw, nblocks);

			/* The helper leaves walk.iv as it found it. */
			for (k = 0; k < nblocks; k++)
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			in += nblocks * SM4_BLOCK_SIZE;
			out += nblocks * SM4_BLOCK_SIZE;
			left -= nblocks * SM4_BLOCK_SIZE;
		}

		if (left > 0 && walk.nbytes == walk.total) {
			rep_xcrypt_ecb_ONE(walk.iv, keystream, ctx->rkey_enc, walk.iv, cw, 1);
			crypto_xor_cpy(out, keystream, in, left);
			left = 0;
		}

		err = skcipher_walk_done(&walk, left);
	}

	return err;
}

/*
 * ctr_encrypt is used for ZX-E and newer.
 * Control word: 0x20 | SM4_CTR with the encdec bit set.
 */
static int ctr_encrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CTR;
	cw.cword.b.encdec = 1;

	return sm4_cipher_ctr(req, &cw);
}

/*
 * ctr_decrypt is used for ZX-E and newer.
 * Control word: 0x20 | SM4_CTR with the encdec bit clear.
 */
static int ctr_decrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CTR;

	return sm4_cipher_ctr(req, &cw);
}

/*
 * sm4_ctr_zxc is used for ZXC+.
 *
 * CTR mode built from single-block ECB encryptions: for each block the
 * counter in walk.iv is encrypted into a keystream block, the counter is
 * incremented, and the keystream is XOR-ed with the source. A trailing
 * partial block is only handled on the final walk step and does not
 * advance the counter.
 */
static int sm4_ctr_zxc(struct skcipher_request *req, struct sm4_cipher_data *cw)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(crypto_skcipher_reqtfm(req));
	struct skcipher_walk walk;
	u8 keystream[SM4_BLOCK_SIZE] = {0};
	unsigned int left;
	u8 *in, *out;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	for (left = walk.nbytes; left > 0; left = walk.nbytes) {
		in = walk.src.virt.addr;
		out = walk.dst.virt.addr;

		for (; left >= SM4_BLOCK_SIZE; left -= SM4_BLOCK_SIZE) {
			rep_xcrypt_ecb_ONE(walk.iv, keystream, ctx->rkey_enc, walk.iv, cw, 1);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			crypto_xor_cpy(out, keystream, in, SM4_BLOCK_SIZE);

			in += SM4_BLOCK_SIZE;
			out += SM4_BLOCK_SIZE;
		}

		/* Tail shorter than one block. */
		if (left > 0 && walk.nbytes == walk.total) {
			rep_xcrypt_ecb_ONE(walk.iv, keystream, ctx->rkey_enc, walk.iv, cw, 1);
			crypto_xor_cpy(out, keystream, in, left);
			left = 0;
		}

		err = skcipher_walk_done(&walk, left);
	}

	return err;
}

/*
 * ctr_encrypt_zxc is used for ZX-C+.
 * Control word: 0x20 | SM4_CTR with the encdec bit set.
 */
static int ctr_encrypt_zxc(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CTR;
	cw.cword.b.encdec = 1;

	return sm4_ctr_zxc(req, &cw);
}

/*
 * ctr_decrypt_zxc is used for ZX-C+.
 * Control word: 0x20 | SM4_CTR with the encdec bit clear.
 */
static int ctr_decrypt_zxc(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CTR;

	return sm4_ctr_zxc(req, &cw);
}

/*
 * ofb_encrypt is used for ZX-E and newer.
 * Control word: 0x20 | SM4_OFB with the encdec bit set.
 */
static int ofb_encrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_OFB;
	cw.cword.b.encdec = 1;

	return sm4_cipher_common(req, &cw);
}

/*
 * ofb_decrypt is used for ZX-E and newer.
 * Control word: 0x20 | SM4_OFB with the encdec bit clear.
 */
static int ofb_decrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_OFB;

	return sm4_cipher_common(req, &cw);
}

/*
 * sm4_ofb_zxc is used for ZX-C+
 */
static int sm4_ofb_zxc(struct skcipher_request *req, struct sm4_cipher_data *cw)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int blocks;
	int err;

	u32 n;

	err = skcipher_walk_virt(&walk, req, true);

	while ((blocks = (walk.nbytes / SM4_BLOCK_SIZE))) {
		while (blocks--) {

			rep_xcrypt_ecb_ONE(walk.iv, walk.iv, ctx->rkey_enc, NULL, cw, 1);

			for (n = 0; n < SM4_BLOCK_SIZE; n += sizeof(size_t))
				*(size_t *)(walk.dst.virt.addr + n) =
					*(size_t *)(walk.iv + n) ^
					*(size_t *)(walk.src.virt.addr + n);

			walk.src.virt.addr += SM4_BLOCK_SIZE;
			walk.dst.virt.addr += SM4_BLOCK_SIZE;

		}

		err = skcipher_walk_done(&walk, walk.nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

/*
 * ofb_encrypt_zxc is used for ZX-C+.
 * Control word: 0x20 | SM4_OFB with the encdec bit set.
 */
static int ofb_encrypt_zxc(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_OFB;
	cw.cword.b.encdec = 1;

	return sm4_ofb_zxc(req, &cw);
}

/*
 * ofb_decrypt_zxc is used for ZX-C+.
 * Control word: 0x20 | SM4_OFB with the encdec bit clear.
 */
static int ofb_decrypt_zxc(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_OFB;

	return sm4_ofb_zxc(req, &cw);
}


/*
 * cfb_encrypt is used for ZX-E and newer.
 * Control word: 0x20 | SM4_CFB with the encdec bit set.
 */
static int cfb_encrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CFB;
	cw.cword.b.encdec = 1;

	return sm4_cipher_common(req, &cw);
}

/*
 * cfb_decrypt is used for ZX-E and newer.
 * Control word: 0x20 | SM4_CFB with the encdec bit clear.
 */
static int cfb_decrypt(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CFB;

	return sm4_cipher_common(req, &cw);
}

/*
 * sm4_cfb_zxc is used for ZX-C+
 */
static int sm4_cfb_zxc(struct skcipher_request *req, struct sm4_cipher_data *cw)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int blocks;
	int err;
	u32 n;
	size_t t;

	err = skcipher_walk_virt(&walk, req, true);

	while ((blocks = (walk.nbytes / SM4_BLOCK_SIZE))) {
		while (blocks--) {
			rep_xcrypt_ecb_ONE(walk.iv, walk.iv, ctx->rkey_enc, NULL, cw, 1);

			if (cw->cword.b.encdec)
				for (n = 0; n < SM4_BLOCK_SIZE; n += sizeof(size_t))
					*(size_t *)(walk.dst.virt.addr + n) =
						*(size_t *)(walk.iv + n) ^=
						*(size_t *)(walk.src.virt.addr + n);

			else
				for (n = 0; n < SM4_BLOCK_SIZE; n += sizeof(size_t)) {
					t = *(size_t *)(walk.src.virt.addr + n);
					*(size_t *)(walk.dst.virt.addr + n) =
						*(size_t *)(walk.iv + n) ^ t;
					*(size_t *)(walk.iv + n) = t;
				}

			walk.src.virt.addr += SM4_BLOCK_SIZE;
			walk.dst.virt.addr += SM4_BLOCK_SIZE;
		}

		err = skcipher_walk_done(&walk, walk.nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

/*
 * cfb_encrypt_zxc is used for ZX-C+.
 * Control word: 0x20 | SM4_CFB with the encdec bit set.
 */
static int cfb_encrypt_zxc(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CFB;
	cw.cword.b.encdec = 1;

	return sm4_cfb_zxc(req, &cw);
}

/*
 * cfb_decrypt_zxc is used for ZX-C+.
 * Control word: 0x20 | SM4_CFB with the encdec bit clear.
 */
static int cfb_decrypt_zxc(struct skcipher_request *req)
{
	struct sm4_cipher_data cw;

	cw.cword.pad = 0x20 | SM4_CFB;

	return sm4_cfb_zxc(req, &cw);
}


/*
 * Internal ("__"-prefixed) skcipher implementations. gmi_sm4_init()
 * registers them and then creates one SIMD wrapper per entry under the
 * same name without the "__" prefix. The encrypt/decrypt handlers of the
 * CTR, OFB and CFB entries are the ZX-E variants; gmi_sm4_init() swaps in
 * the *_zxc handlers when gmi_zxc_check() reports a ZX-C+ part.
 */
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name           = "__ecb(sm4)",
			.cra_driver_name    = "__ecb-sm4-gmi",
			.cra_priority       = 300,
			.cra_flags          = CRYPTO_ALG_INTERNAL,
			.cra_blocksize      = SM4_BLOCK_SIZE,
			.cra_ctxsize        = sizeof(struct sm4_ctx),
			.cra_module         = THIS_MODULE,
		},
		.min_keysize    = SM4_KEY_SIZE,
		.max_keysize    = SM4_KEY_SIZE,
		/* ECB takes no IV, so no ivsize is advertised. */
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey         = gmi_sm4_set_key,
		.encrypt        = ecb_encrypt,
		.decrypt        = ecb_decrypt,
	},

	{
		.base = {
			.cra_name           = "__cbc(sm4)",
			.cra_driver_name    = "__cbc-sm4-gmi",
			.cra_priority       = 300,
			.cra_flags          = CRYPTO_ALG_INTERNAL,
			.cra_blocksize      = SM4_BLOCK_SIZE,
			.cra_ctxsize        = sizeof(struct sm4_ctx),
			.cra_module         = THIS_MODULE,
		},
		.min_keysize    = SM4_KEY_SIZE,
		.max_keysize    = SM4_KEY_SIZE,
		.ivsize         = SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey         = gmi_sm4_set_key,
		.encrypt        = cbc_encrypt,
		.decrypt        = cbc_decrypt,
	},

	{
		.base = {
			.cra_name           = "__ctr(sm4)",
			.cra_driver_name    = "__ctr-sm4-gmi",
			.cra_priority       = 300,
			.cra_flags          = CRYPTO_ALG_INTERNAL,
			/* Byte-granular requests; chunksize carries the block size. */
			.cra_blocksize      = 1,
			.cra_ctxsize        = sizeof(struct sm4_ctx),
			.cra_module         = THIS_MODULE,
		},
		.min_keysize    = SM4_KEY_SIZE,
		.max_keysize    = SM4_KEY_SIZE,
		.ivsize         = SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey         = gmi_sm4_set_key,
		.encrypt        = ctr_encrypt,
		.decrypt        = ctr_decrypt,
	},

	{
		.base = {
			.cra_name           = "__ofb(sm4)",
			.cra_driver_name    = "__ofb-sm4-gmi",
			.cra_priority       = 300,
			.cra_flags          = CRYPTO_ALG_INTERNAL,
			.cra_blocksize      = SM4_BLOCK_SIZE,
			.cra_ctxsize        = sizeof(struct sm4_ctx),
			.cra_module         = THIS_MODULE,
		},
		.min_keysize    = SM4_KEY_SIZE,
		.max_keysize    = SM4_KEY_SIZE,
		.ivsize         = SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey         = gmi_sm4_set_key,
		.encrypt        = ofb_encrypt,
		.decrypt        = ofb_decrypt,
	},

	{
		.base = {
			.cra_name           = "__cfb(sm4)",
			.cra_driver_name    = "__cfb-sm4-gmi",
			.cra_priority       = 300,
			.cra_flags          = CRYPTO_ALG_INTERNAL,
			.cra_blocksize      = SM4_BLOCK_SIZE,
			.cra_ctxsize        = sizeof(struct sm4_ctx),
			.cra_module         = THIS_MODULE,
		},
		.min_keysize    = SM4_KEY_SIZE,
		.max_keysize    = SM4_KEY_SIZE,
		.ivsize         = SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey         = gmi_sm4_set_key,
		.encrypt        = cfb_encrypt,
		.decrypt        = cfb_decrypt,
	}
};

/* SIMD wrappers created by gmi_sm4_init(), one slot per sm4_algs[] entry. */
static struct simd_skcipher_alg *sm4_simd_algs[ARRAY_SIZE(sm4_algs)];

/*
 * gmi_zxc_check - report whether CPU 0 needs the *_zxc handlers.
 *
 * Returns 1 for family 6 with model >= 0x0f or model == 0x09, and 0 for
 * every other family/model combination (including all families above 6).
 */
static int gmi_zxc_check(void)
{
	struct cpuinfo_x86 *c = &cpu_data(0);

	if (c->x86 != 6)
		return 0;

	return (c->x86_model >= 0x0f || c->x86_model == 0x09) ? 1 : 0;
}

/*
 * gmi_ccs_available - check whether the SM4 instruction can be used.
 *
 * Only family 6 (model >= 0x0f or model == 0x09) and families above 6 are
 * considered. On those, the X86_FEATURE_CCS and X86_FEATURE_CCS_EN feature
 * flags are consulted first; if either one is missing, CPUID leaf
 * 0xC0000001 is queried directly and bits 4 and 5 of EDX must both be set.
 *
 * Returns 0 when the instruction is usable, -ENODEV otherwise.
 */
static int gmi_ccs_available(void)
{
	struct cpuinfo_x86 *c = &cpu_data(0);
	u32 eax, ebx, ecx, edx;

	if (!(((c->x86 == 6) && (c->x86_model >= 0x0f))
		|| ((c->x86 == 6) && (c->x86_model == 0x09))
		|| (c->x86 > 6)))
		return -ENODEV;

	if (boot_cpu_has(X86_FEATURE_CCS) && boot_cpu_has(X86_FEATURE_CCS_EN)) {
		pr_notice("GMI SM4 is available\n");
		return 0;
	}

	/*
	 * CPUID overwrites EAX, EBX, ECX and EDX, so all four must be
	 * declared as outputs; otherwise the compiler may keep live values
	 * in registers that the instruction destroys.
	 */
	eax = 0xC0000001;
	ecx = 0;
	__asm__ __volatile__ ("cpuid"
		: "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
		: "0"(eax), "2"(ecx));

	if ((edx & 0x0030) != 0x0030)
		return -ENODEV;

	pr_notice("GMI SM4 is detected by CPUID\n");
	return 0;
}


/*
 * gmi_sm4_exit - free the SIMD wrappers and unregister the internal
 * algorithms.
 *
 * Wrappers are freed in array order up to the first empty slot, which
 * lets gmi_sm4_init() call this on its error path after a partial setup.
 */
static void gmi_sm4_exit(void)
{
	int i = 0;

	while (i < ARRAY_SIZE(sm4_simd_algs) && sm4_simd_algs[i]) {
		simd_skcipher_free(sm4_simd_algs[i]);
		i++;
	}

	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}
/*
 * gmi_sm4_init - probe the CPU and register the SM4 algorithms.
 *
 * Steps:
 *  1. Bail out with -ENODEV unless gmi_ccs_available() succeeds.
 *  2. If gmi_zxc_check() is true, replace the CTR, CFB and OFB handlers in
 *     sm4_algs[] with their *_zxc counterparts.
 *  3. Register the internal algorithms, then create a SIMD wrapper for
 *     each of them; the wrapper names are the internal names with the
 *     leading "__" skipped.
 *
 * On a wrapper creation failure everything set up so far is torn down via
 * gmi_sm4_exit() and the error is returned.
 */
static int __init gmi_sm4_init(void)
{
	struct simd_skcipher_alg *simd;
	struct skcipher_alg *alg;
	int err;
	int i;

	if (gmi_ccs_available() != 0)
		return -ENODEV;

	if (gmi_zxc_check()) {
		for (i = 0; i < ARRAY_SIZE(sm4_algs); i++) {
			alg = &sm4_algs[i];

			if (strcmp(alg->base.cra_name, "__ctr(sm4)") == 0) {
				alg->encrypt = ctr_encrypt_zxc;
				alg->decrypt = ctr_decrypt_zxc;
			} else if (strcmp(alg->base.cra_name, "__cfb(sm4)") == 0) {
				alg->encrypt = cfb_encrypt_zxc;
				alg->decrypt = cfb_decrypt_zxc;
			} else if (strcmp(alg->base.cra_name, "__ofb(sm4)") == 0) {
				alg->encrypt = ofb_encrypt_zxc;
				alg->decrypt = ofb_decrypt_zxc;
			}
		}
	}

	err = crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(sm4_algs); i++) {
		alg = &sm4_algs[i];

		/* "+ 2" skips the "__" prefix of the internal names. */
		simd = simd_skcipher_create_compat(alg->base.cra_name + 2,
						   alg->base.cra_driver_name + 2,
						   alg->base.cra_driver_name);
		if (IS_ERR(simd)) {
			err = PTR_ERR(simd);
			gmi_sm4_exit();
			return err;
		}

		sm4_simd_algs[i] = simd;
	}

	return 0;
}

/* Hook up gmi_sm4_init()/gmi_sm4_exit() and describe the module. */
late_initcall(gmi_sm4_init);
module_exit(gmi_sm4_exit);

MODULE_DESCRIPTION("SM4-ECB/CBC/CTR/CFB/OFB using Zhaoxin GMI");
MODULE_AUTHOR("GRX");
MODULE_LICENSE("GPL");