From c57b4cad76c85b9566a1c264fb9bae3c1bddfab3 Mon Sep 17 00:00:00 2001 From: yuzexi Date: Tue, 22 Oct 2024 11:37:56 +0800 Subject: [PATCH] driver: crypto: hisilicon: add rsa algorithm add rsa algorithm Signed-off-by: yuzexi --- core/drivers/crypto/hisilicon/crypto.mk | 1 + core/drivers/crypto/hisilicon/hpre_rsa.c | 1398 ++++++++++++++++++++++ core/drivers/crypto/hisilicon/hpre_rsa.h | 47 + core/drivers/crypto/hisilicon/sub.mk | 1 + 4 files changed, 1447 insertions(+) create mode 100644 core/drivers/crypto/hisilicon/hpre_rsa.c create mode 100644 core/drivers/crypto/hisilicon/hpre_rsa.h diff --git a/core/drivers/crypto/hisilicon/crypto.mk b/core/drivers/crypto/hisilicon/crypto.mk index 6b7cc7ca25e..e5148a024d5 100644 --- a/core/drivers/crypto/hisilicon/crypto.mk +++ b/core/drivers/crypto/hisilicon/crypto.mk @@ -11,6 +11,7 @@ $(call force,CFG_CRYPTO_DRV_AUTHENC,y,Mandated by CFG_HISILICON_CRYPTO_DRIVER) ifeq ($(CFG_HISILICON_ACC_V3), y) $(call force, CFG_CRYPTO_DRV_DH,y,Mandated by CFG_HISILICON_ACC_V3) $(call force,CFG_CRYPTO_DRV_ECC,y,Mandated by CFG_HISILICON_ACC_V3) +$(call force,CFG_CRYPTO_DRV_RSA,y,Mandated by CFG_HISILICON_ACC_V3) endif endif diff --git a/core/drivers/crypto/hisilicon/hpre_rsa.c b/core/drivers/crypto/hisilicon/hpre_rsa.c new file mode 100644 index 00000000000..63f3f685851 --- /dev/null +++ b/core/drivers/crypto/hisilicon/hpre_rsa.c @@ -0,0 +1,1398 @@ +// SPDX-License-Identifier: BSD-2-Clause +/* + * Copyright (c) 2022-2024, HiSilicon Technologies Co., Ltd. + * Kunpeng hardware accelerator hpre rsa algorithm implementation. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hpre_main.h" +#include "hpre_rsa.h" + +static enum hisi_drv_status hpre_rsa_fill_addr_params(struct hpre_rsa_msg *msg, + struct hpre_sqe *sqe) +{ + switch (msg->alg_type) { + case HPRE_ALG_KG_CRT: /* KEY GEN */ + sqe->key = msg->in_dma; + sqe->out = msg->out_dma; + return HISI_QM_DRVCRYPT_NO_ERR; + case HPRE_ALG_NC_NCRT: + case HPRE_ALG_NC_CRT: + if (msg->is_private) { + /* DECRYPT */ + sqe->key = msg->prikey_dma; + sqe->in = msg->in_dma; + sqe->out = msg->out_dma; + } else { + /* ENCRYPT */ + sqe->key = msg->pubkey_dma; + sqe->in = msg->in_dma; + sqe->out = msg->out_dma; + } + return HISI_QM_DRVCRYPT_NO_ERR; + default: + EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type); + return HISI_QM_DRVCRYPT_IN_EPARA; + } +} + +static enum hisi_drv_status hpre_rsa_fill_sqe(void *bd, void *info) +{ + struct hpre_rsa_msg *msg = info; + struct hpre_sqe *sqe = bd; + + sqe->w0 = msg->alg_type | SHIFT_U32(0x1, HPRE_DONE_SHIFT); + sqe->task_len1 = TASK_LENGTH(msg->key_bytes); + + return hpre_rsa_fill_addr_params(msg, sqe); +} + +static enum hisi_drv_status hpre_rsa_parse_sqe(void *bd, void *info __unused) +{ + struct hpre_sqe *sqe = bd; + uint16_t err = 0; + uint16_t done = 0; + + err = HPRE_TASK_ETYPE(sqe->w0); + done = HPRE_TASK_DONE(sqe->w0); + if (done != HPRE_HW_TASK_DONE || err) { + EMSG("HPRE do rsa fail! done=0x%"PRIX16", etype=0x%"PRIX16, + done, err); + return HISI_QM_DRVCRYPT_IN_EPARA; + } + + return HISI_QM_DRVCRYPT_NO_ERR; +} + +static TEE_Result hpre_rsa_do_task(void *msg) +{ + struct hisi_qp *rsa_qp = NULL; + TEE_Result res = TEE_SUCCESS; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + + rsa_qp = hpre_create_qp(HISI_QM_CHANNEL_TYPE0); + if (!rsa_qp) { + EMSG("Fail to create rsa qp"); + return TEE_ERROR_BUSY; + } + + rsa_qp->fill_sqe = hpre_rsa_fill_sqe; + rsa_qp->parse_sqe = hpre_rsa_parse_sqe; + ret = hisi_qp_send(rsa_qp, msg); + if (ret) { + EMSG("Fail to send task, ret=%d", ret); + res = TEE_ERROR_BAD_STATE; + goto done_proc; + } + + ret = hisi_qp_recv_sync(rsa_qp, msg); + if (ret) { + EMSG("Recv task error, ret=%d", ret); + res = TEE_ERROR_BAD_STATE; + goto done_proc; + } + +done_proc: + hisi_qm_release_qp(rsa_qp); + + return res; +} + +static int32_t hpre_rsa_get_random_data(void *prng __unused, + uint8_t *buf, size_t len) +{ + if (hw_get_random_bytes(buf, len)) { + EMSG("Fail to get random data"); + return MBEDTLS_ERR_MPI_FILE_IO_ERROR; + } + + return 0; +} + +static void hpre_rsa_fill_p_q(struct rsa_keypair *key, mbedtls_mpi *p, + mbedtls_mpi *q) +{ + crypto_bignum_copy(key->p, (struct bignum *)p); + crypto_bignum_copy(key->q, (struct bignum *)q); +} + +static int32_t hpre_rsa_gen_p_q(size_t nbits, struct rsa_keypair *key) +{ + int32_t prime_quality = 0; + mbedtls_rsa_context rsa = { }; + mbedtls_mpi H = { }; + mbedtls_mpi G = { }; + mbedtls_mpi L = { }; + uint32_t e = 0; + + memset(&rsa, 0, sizeof(rsa)); + mbedtls_mpi_init(&H); + mbedtls_mpi_init(&G); + mbedtls_mpi_init(&L); + + if (nbits > PRIME_QUALITY_FLAG) + prime_quality = MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR; + + mbedtls_mpi_write_binary((mbedtls_mpi *)key->e, (uint8_t *)&e, + sizeof(uint32_t)); + e = TEE_U32_FROM_BIG_ENDIAN(e); + mbedtls_mpi_lset(&rsa.E, (int32_t)e); + + while (true) { + MBEDTLS_MPI_CHK(mbedtls_mpi_gen_prime(&rsa.P, nbits >> 1, + prime_quality, + hpre_rsa_get_random_data, + NULL)); + + MBEDTLS_MPI_CHK(mbedtls_mpi_gen_prime(&rsa.Q, nbits >> 1, + prime_quality, + hpre_rsa_get_random_data, + NULL)); + + /* Temporarily replace P,Q by P-1, Q-1 */ + MBEDTLS_MPI_CHK(mbedtls_mpi_sub_int(&rsa.P, &rsa.P, 1)); + MBEDTLS_MPI_CHK(mbedtls_mpi_sub_int(&rsa.Q, &rsa.Q, 1)); + MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&H, &rsa.P, &rsa.Q)); + + /* check GCD( E, (P-1)*(Q-1) ) == 1 */ + MBEDTLS_MPI_CHK(mbedtls_mpi_gcd(&G, &rsa.E, &H)); + if (mbedtls_mpi_cmp_int(&G, 1)) + continue; + + MBEDTLS_MPI_CHK(mbedtls_mpi_gcd(&G, &rsa.P, &rsa.Q)); + MBEDTLS_MPI_CHK(mbedtls_mpi_div_mpi(&L, NULL, &H, &G)); + MBEDTLS_MPI_CHK(mbedtls_mpi_inv_mod(&rsa.D, &rsa.E, &L)); + + if (mbedtls_mpi_bitlen(&rsa.D) > ((nbits + 1) >> 1)) + break; + } + + /* Restore P,Q */ + MBEDTLS_MPI_CHK(mbedtls_mpi_add_int(&rsa.P, &rsa.P, 1)); + MBEDTLS_MPI_CHK(mbedtls_mpi_add_int(&rsa.Q, &rsa.Q, 1)); + + hpre_rsa_fill_p_q(key, &rsa.P, &rsa.Q); + +cleanup: + mbedtls_mpi_free(&H); + mbedtls_mpi_free(&G); + mbedtls_mpi_free(&L); + + return 0; +} + +static TEE_Result mgf_process(size_t digest_size, uint8_t *seed, + size_t seed_len, uint8_t *mask, size_t mask_len, + struct drvcrypt_rsa_ed *rsa_data) +{ + struct drvcrypt_rsa_mgf mgf = { }; + + if (!rsa_data->mgf) { + EMSG("mgf function is NULL"); + return TEE_ERROR_BAD_PARAMETERS; + } + + mgf.hash_algo = rsa_data->hash_algo; + mgf.digest_size = digest_size; + mgf.seed.data = seed; + mgf.seed.length = seed_len; + mgf.mask.data = mask; + mgf.mask.length = mask_len; + + return rsa_data->mgf(&mgf); +} + +static TEE_Result xor_process(uint8_t *a, uint8_t *b, uint8_t *out, size_t len) +{ + struct drvcrypt_mod_op xor_mod = { }; + + xor_mod.n.length = len; + xor_mod.a.data = a; + xor_mod.a.length = len; + xor_mod.b.data = b; + xor_mod.b.length = len; + xor_mod.result.data = out; + xor_mod.result.length = len; + + return drvcrypt_xor_mod_n(&xor_mod); +} + +static void hpre_rsa_free_keypair(struct rsa_keypair *key) +{ + if (key) { + crypto_bignum_free(&key->e); + crypto_bignum_free(&key->d); + crypto_bignum_free(&key->n); + crypto_bignum_free(&key->p); + crypto_bignum_free(&key->q); + crypto_bignum_free(&key->qp); + crypto_bignum_free(&key->dp); + crypto_bignum_free(&key->dq); + } +} + +static TEE_Result hpre_rsa_allocate_keypair(struct rsa_keypair *key, + size_t size_bits) +{ + if (!key || !size_bits) { + EMSG("Invalid input parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + memset(key, 0, sizeof(*key)); + key->e = crypto_bignum_allocate(size_bits); + if (!key->e) + goto alloc_fail; + + key->d = crypto_bignum_allocate(size_bits); + if (!key->d) + goto alloc_fail; + + key->n = crypto_bignum_allocate(size_bits); + if (!key->n) + goto alloc_fail; + + key->p = crypto_bignum_allocate(size_bits >> 1); + if (!key->p) + goto alloc_fail; + + key->q = crypto_bignum_allocate(size_bits >> 1); + if (!key->q) + goto alloc_fail; + + key->dp = crypto_bignum_allocate(size_bits >> 1); + if (!key->dp) + goto alloc_fail; + + key->dq = crypto_bignum_allocate(size_bits >> 1); + if (!key->dq) + goto alloc_fail; + + key->qp = crypto_bignum_allocate(size_bits >> 1); + if (!key->qp) + goto alloc_fail; + + return TEE_SUCCESS; + +alloc_fail: + EMSG("HPRE rsa alloc key pair fail"); + hpre_rsa_free_keypair(key); + + return TEE_ERROR_OUT_OF_MEMORY; +} + +static TEE_Result hpre_rsa_allocate_publickey(struct rsa_public_key *key, + size_t size_bits) +{ + if (!key) { + EMSG("Invalid input parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + memset(key, 0, sizeof(*key)); + key->e = crypto_bignum_allocate(size_bits); + if (!key->e) + return TEE_ERROR_OUT_OF_MEMORY; + + key->n = crypto_bignum_allocate(size_bits); + if (!key->n) { + crypto_bignum_free(&key->e); + return TEE_ERROR_OUT_OF_MEMORY; + } + + return TEE_SUCCESS; +} + +static void hpre_rsa_free_publickey(struct rsa_public_key *key) +{ + if (key) { + crypto_bignum_free(&key->e); + crypto_bignum_free(&key->n); + } +} + +static TEE_Result hpre_rsa_get_keypair_result(struct hpre_rsa_msg *msg, + struct rsa_keypair *key) +{ + uint8_t *d = NULL; + uint8_t *n = NULL; + uint8_t *qp = NULL; + uint8_t *dq = NULL; + uint8_t *dp = NULL; + TEE_Result ret = TEE_SUCCESS; + + d = msg->out; + n = msg->out + msg->key_bytes; + qp = n + msg->key_bytes; + dq = qp + (msg->key_bytes >> 1); + dp = dq + (msg->key_bytes >> 1); + + ret = crypto_bignum_bin2bn(d, msg->key_bytes, key->d); + if (ret) { + EMSG("Fail to bin2bn param d"); + return ret; + } + + ret = crypto_bignum_bin2bn(n, msg->key_bytes, key->n); + if (ret) { + EMSG("Fail to bin2bn param n"); + return ret; + } + + ret = crypto_bignum_bin2bn(qp, msg->key_bytes >> 1, key->qp); + if (ret) { + EMSG("Fail to bin2bn param qp"); + return ret; + } + + ret = crypto_bignum_bin2bn(dq, msg->key_bytes >> 1, key->dq); + if (ret) { + EMSG("Fail to bin2bn param dq"); + return ret; + } + + ret = crypto_bignum_bin2bn(dp, msg->key_bytes >> 1, key->dp); + if (ret) + EMSG("Fail to bin2bn param dp"); + + return ret; +} + +static size_t hpre_rsa_get_hw_kbytes(size_t key_bits) +{ + size_t size = 0; + + if (key_bits <= 1024) + size = BITS_TO_BYTES(1024); + else if (key_bits <= 2048) + size = BITS_TO_BYTES(2048); + else if (key_bits <= 3072) + size = BITS_TO_BYTES(3072); + else if (key_bits <= 4096) + size = BITS_TO_BYTES(4096); + else + EMSG("Invalid key_bits[%zu]", key_bits); + + return size; +} + +static void hpre_rsa_params_free(struct hpre_rsa_msg *msg) +{ + size_t len = 0; + + switch (msg->alg_type) { + case HPRE_ALG_KG_CRT: + free_wipe(msg->in); + break; + case HPRE_ALG_NC_NCRT: + if (msg->is_private) { + free_wipe(msg->prikey); + } else { + free(msg->pubkey); + } + break; + case HPRE_ALG_NC_CRT: + if (msg->is_private) { + len = HPRE_RSA_CRT_KEY_BUF_SIZE(msg->key_bytes); + free_wipe(msg->prikey); + } + break; + default: + EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type); + break; + } +} + +static enum hisi_drv_status hpre_rsa_keypair_alloc(struct hpre_rsa_msg *msg) +{ + uint32_t size = HPRE_RSA_GEN_TOTAL_BUF_SIZE(msg->key_bytes); + uint8_t *data = NULL; + + data = calloc(1, size); + if (!data) { + EMSG("Fail to alloc rsa gen buf"); + return HISI_QM_DRVCRYPT_ENOMEM; + } + + msg->in = data; + msg->in_dma = virt_to_phys(msg->in); + + msg->out = data + (msg->key_bytes << 1); + msg->out_dma = msg->in_dma + (msg->key_bytes << 1); + + return HISI_QM_DRVCRYPT_NO_ERR; +} + +static enum hisi_drv_status hpre_rsa_keypair_bn2bin(struct hpre_rsa_msg *msg, + struct rsa_keypair *key) +{ + uint32_t e_len = 0; + uint32_t p_len = 0; + uint32_t q_len = 0; + uint8_t *e = NULL; + uint8_t *p = NULL; + uint8_t *q = NULL; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + + e = msg->in; + p = msg->in + msg->key_bytes; + q = p + (msg->key_bytes >> 1); + + crypto_bignum_bn2bin(key->e, e); + crypto_bignum_bn2bin(key->p, p); + crypto_bignum_bn2bin(key->q, q); + e_len = crypto_bignum_num_bytes(key->e); + p_len = crypto_bignum_num_bytes(key->p); + q_len = crypto_bignum_num_bytes(key->q); + + ret = hpre_bin_from_crypto_bin(e, e, msg->key_bytes, e_len); + if (ret) { + EMSG("Fail to transfer rsa e from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(p, p, msg->key_bytes >> 1, p_len); + if (ret) { + EMSG("Fail to transfer rsa p from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(q, q, msg->key_bytes >> 1, q_len); + if (ret) + EMSG("Fail to transfer rsa q from crypto_bin to hpre_bin"); + + return ret; +} + +static TEE_Result hpre_rsa_keypair_init(struct hpre_rsa_msg *msg, + struct rsa_keypair *key, + size_t key_bits) +{ + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + + msg->alg_type = HPRE_ALG_KG_CRT; + msg->key_bytes = hpre_rsa_get_hw_kbytes(key_bits); + if (!msg->key_bytes) + return TEE_ERROR_BAD_PARAMETERS; + + ret = hpre_rsa_keypair_alloc(msg); + if (ret) + return TEE_ERROR_BAD_STATE; + + ret = hpre_rsa_keypair_bn2bin(msg, key); + if (ret) { + hpre_rsa_params_free(msg); + return TEE_ERROR_BAD_STATE; + } + + return TEE_SUCCESS; +} + +static TEE_Result hpre_rsa_gen_keypair(struct rsa_keypair *key, size_t key_bits) +{ + struct hpre_rsa_msg msg = { }; + TEE_Result ret = TEE_SUCCESS; + + if (!key || !key->n || !key->e || !key->d || !key->p || !key->q || + !key->dp || !key->dq || !key->qp) { + EMSG("Invalid rsa gen key pair input parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + if (!crypto_bignum_num_bits(key->e)) { + EMSG("Invalid parameter e"); + return TEE_ERROR_BAD_PARAMETERS; + } + + if (!crypto_bignum_num_bits(key->p) || + !crypto_bignum_num_bits(key->q)) { + if (hpre_rsa_gen_p_q(key_bits, key)) { + EMSG("Fail to gen prime p and q"); + return TEE_ERROR_BAD_PARAMETERS; + } + } + + ret = hpre_rsa_keypair_init(&msg, key, key_bits); + if (ret) { + EMSG("Fail to init rsa msg"); + return ret; + } + + ret = hpre_rsa_do_task(&msg); + if (ret) + goto free_all; + + ret = hpre_rsa_get_keypair_result(&msg, key); + +free_all: + hpre_rsa_params_free(&msg); + + return ret; +} + +static enum hisi_drv_status hpre_rsa_encrypt_alloc(struct hpre_rsa_msg *msg) +{ + uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes); + uint8_t *data = NULL; + + data = calloc(1, size); + if (!data) { + EMSG("Fail to alloc rsa ncrt buf"); + return HISI_QM_DRVCRYPT_ENOMEM; + } + + msg->pubkey = data; + msg->pubkey_dma = virt_to_phys(msg->pubkey); + + msg->in = data + (msg->key_bytes << 1); + msg->in_dma = msg->pubkey_dma + (msg->key_bytes << 1); + + msg->out = msg->in + msg->key_bytes; + msg->out_dma = msg->in_dma + msg->key_bytes; + + return HISI_QM_DRVCRYPT_NO_ERR; +} + +static enum hisi_drv_status +hpre_rsa_encrypt_bn2bin(struct hpre_rsa_msg *msg, + struct drvcrypt_rsa_ed *rsa_data) +{ + struct rsa_public_key *key = rsa_data->key.key; + uint32_t e_len = 0; + uint32_t n_len = 0; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + uint8_t *n = NULL; + + n = msg->pubkey + msg->key_bytes; + + crypto_bignum_bn2bin(key->e, msg->pubkey); + crypto_bignum_bn2bin(key->n, n); + e_len = crypto_bignum_num_bytes(key->e); + n_len = crypto_bignum_num_bytes(key->n); + + ret = hpre_bin_from_crypto_bin(msg->pubkey, msg->pubkey, + msg->key_bytes, e_len); + if (ret) { + EMSG("Fail to transfer rsa ncrt e from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len); + if (ret) { + EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->message.data, + msg->key_bytes, + rsa_data->message.length); + if (ret) + EMSG("Fail to transfer rsa plaintext from crypto_bin to hpre_bin"); + + return ret; +} + +static TEE_Result hpre_rsa_encrypt_init(struct hpre_rsa_msg *msg, + struct drvcrypt_rsa_ed *rsa_data) +{ + size_t n_bytes = rsa_data->key.n_size; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + + msg->alg_type = HPRE_ALG_NC_NCRT; + msg->is_private = rsa_data->key.isprivate; + msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes)); + if (!msg->key_bytes) + return TEE_ERROR_BAD_PARAMETERS; + + ret = hpre_rsa_encrypt_alloc(msg); + if (ret) + return TEE_ERROR_BAD_STATE; + + ret = hpre_rsa_encrypt_bn2bin(msg, rsa_data); + if (ret) { + hpre_rsa_params_free(msg); + return TEE_ERROR_BAD_STATE; + } + + return TEE_SUCCESS; +} + +static TEE_Result rsa_nopad_encrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + size_t n_bytes = rsa_data->key.n_size; + struct hpre_rsa_msg msg = { }; + TEE_Result ret = TEE_SUCCESS; + + if (rsa_data->message.length > n_bytes) { + EMSG("Invalid msg length[%zu]", rsa_data->message.length); + return TEE_ERROR_BAD_PARAMETERS; + } + + ret = hpre_rsa_encrypt_init(&msg, rsa_data); + if (ret) { + EMSG("Fail to init rsa msg"); + return ret; + } + + ret = hpre_rsa_do_task(&msg); + if (ret) + goto encrypt_deinit; + + /* Ciphertext can have valid zero data in NOPAD MODE */ + memcpy(rsa_data->cipher.data, msg.out + msg.key_bytes - n_bytes, + n_bytes); + rsa_data->cipher.length = n_bytes; + +encrypt_deinit: + hpre_rsa_params_free(&msg); + + return ret; +} + +static TEE_Result pkcs_v1_5_fill_ps(uint8_t *ps, size_t ps_len) +{ + size_t i = 0; + + if (hw_get_random_bytes(ps, ps_len)) { + EMSG("Fail to get ps data"); + return TEE_ERROR_NO_DATA; + } + + for (i = 0; i < ps_len; i++) { + if (ps[i] == 0) + ps[i] = PKCS_V1_5_PS_DATA; + } + + return TEE_SUCCESS; +} + +static TEE_Result rsaes_pkcs_v1_5_encode(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *out) +{ + size_t msg_len = rsa_data->message.length; + size_t out_len = rsa_data->cipher.length; + size_t n_bytes = rsa_data->key.n_size; + uint8_t *ps = out + PKCS_V1_5_PS_POS; + TEE_Result ret = TEE_SUCCESS; + size_t ps_len = 0; + + /* PKCS_V1.5 format 0x00 || 0x02 || PS non-zero || 0x00 || M */ + if ((msg_len + PKCS_V1_5_MSG_MIN_LEN) > n_bytes || out_len < n_bytes) { + EMSG("Invalid pkcs_v1.5 encode parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + ps_len = n_bytes - PKCS_V1_5_FIXED_LEN - msg_len; + ret = pkcs_v1_5_fill_ps(ps, ps_len); + if (ret) + return ret; + + out[0] = 0; + out[1] = ENCRYPT_PAD; + out[PKCS_V1_5_FIXED_LEN + ps_len - 1] = 0; + memcpy(out + PKCS_V1_5_FIXED_LEN + ps_len, rsa_data->message.data, + msg_len); + + return TEE_SUCCESS; +} + +static TEE_Result rsa_pkcs_encrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + uint32_t n_bytes = rsa_data->key.n_size; + struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data; + TEE_Result ret = TEE_SUCCESS; + + /* Alloc pkcs_v1.5 encode message data buf */ + rsa_enc_info.message.data = malloc(n_bytes); + if (!rsa_enc_info.message.data) { + EMSG("Fail to alloc message data buf"); + return TEE_ERROR_OUT_OF_MEMORY; + } + + rsa_enc_info.message.length = n_bytes; + ret = rsaes_pkcs_v1_5_encode(rsa_data, rsa_enc_info.message.data); + if (ret) { + EMSG("Fail to get pkcs_v1.5 encode message data"); + goto free_data; + } + + ret = rsa_nopad_encrypt(&rsa_enc_info); + if (ret) + goto free_data; + + memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data, + rsa_enc_info.cipher.length); + rsa_data->cipher.length = rsa_enc_info.cipher.length; + +free_data: + free(rsa_enc_info.message.data); + + return ret; +} + +static TEE_Result rsa_oaep_fill_db(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *db) +{ + size_t lhash_len = rsa_data->digest_size; + size_t n_bytes = rsa_data->key.n_size; + size_t db_len = n_bytes - lhash_len - 1; + size_t ps_len = 0; + TEE_Result ret = TEE_SUCCESS; + + /* oaep db format lhash || ps zero || 01 || M */ + ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data, + rsa_data->label.length, db, lhash_len); + if (ret) { + EMSG("Fail to get label hash"); + return ret; + } + + ps_len = db_len - lhash_len - rsa_data->message.length - 1; + db[lhash_len + ps_len] = 1; + memcpy(db + lhash_len + ps_len + 1, rsa_data->message.data, + rsa_data->message.length); + + return TEE_SUCCESS; +} + +static TEE_Result rsa_oaep_fill_maskdb(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *seed, uint8_t *db, + uint8_t *mask_db) +{ + size_t lhash_len = rsa_data->digest_size; + size_t n_bytes = rsa_data->key.n_size; + size_t db_len = n_bytes - lhash_len - 1; + uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { }; + TEE_Result ret = TEE_SUCCESS; + + ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len, + rsa_data); + if (ret) { + EMSG("Fail to get seed_mgf"); + return ret; + } + + return xor_process(db, seed_mgf, mask_db, db_len); +} + +static TEE_Result rsa_oaep_fill_maskseed(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *seed, uint8_t *em) +{ + uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { 0 }; + size_t lhash_len = rsa_data->digest_size; + size_t n_bytes = rsa_data->key.n_size; + size_t db_len = n_bytes - lhash_len - 1; + uint8_t *mask_db = em + lhash_len + 1; + uint8_t *mask_seed = em + 1; + TEE_Result ret = TEE_SUCCESS; + + ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len, + rsa_data); + if (ret) { + EMSG("Fail to get mask_db_mgf"); + return ret; + } + + return xor_process(seed, mask_db_mgf, mask_seed, lhash_len); +} + +static TEE_Result rsa_oaep_encode(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *em) +{ + size_t lhash_len = rsa_data->digest_size; + uint8_t db[OAEP_MAX_DB_LEN] = { }; + uint8_t seed[OAEP_MAX_HASH_LEN]; + TEE_Result ret = TEE_SUCCESS; + + /* oaep format 00 || maskedseed || maskeddb */ + em[0] = 0; + + ret = rsa_oaep_fill_db(rsa_data, db); + if (ret) + return ret; + + ret = hw_get_random_bytes(seed, lhash_len); + if (ret) + return ret; + + ret = rsa_oaep_fill_maskdb(rsa_data, seed, db, em + lhash_len + 1); + if (ret) + return ret; + + return rsa_oaep_fill_maskseed(rsa_data, seed, em); +} + +static TEE_Result rsa_oaep_encrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + size_t n_bytes = rsa_data->key.n_size; + struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data; + TEE_Result ret = TEE_SUCCESS; + + /* Alloc oaep encode message data buf */ + rsa_enc_info.message.data = malloc(n_bytes); + if (!rsa_enc_info.message.data) { + EMSG("Fail to alloc message data buf"); + return TEE_ERROR_OUT_OF_MEMORY; + } + + rsa_enc_info.message.length = n_bytes; + ret = rsa_oaep_encode(rsa_data, rsa_enc_info.message.data); + if (ret) { + EMSG("Fail to get oaep encode message data"); + goto free_data; + } + + ret = rsa_nopad_encrypt(&rsa_enc_info); + if (ret) + goto free_data; + + memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data, + rsa_enc_info.cipher.length); + rsa_data->cipher.length = rsa_enc_info.cipher.length; + +free_data: + free(rsa_enc_info.message.data); + + return ret; +} + +static TEE_Result hpre_rsa_encrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + if (!rsa_data) { + EMSG("Invalid rsa encrypt input parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + switch (rsa_data->rsa_id) { + case DRVCRYPT_RSA_NOPAD: + case DRVCRYPT_RSASSA_PKCS_V1_5: + case DRVCRYPT_RSASSA_PSS: + return rsa_nopad_encrypt(rsa_data); + case DRVCRYPT_RSA_PKCS_V1_5: + return rsa_pkcs_encrypt(rsa_data); + case DRVCRYPT_RSA_OAEP: + return rsa_oaep_encrypt(rsa_data); + default: + EMSG("Invalid rsa id"); + return TEE_ERROR_BAD_PARAMETERS; + } +} + +static enum hisi_drv_status hpre_rsa_crt_decrypt_alloc(struct hpre_rsa_msg *msg) +{ + uint32_t size = HPRE_RSA_CRT_TOTAL_BUF_SIZE(msg->key_bytes); + uint8_t *data = NULL; + + data = calloc(1, size); + if (!data) { + EMSG("Fail to alloc rsa crt total buf"); + return HISI_QM_DRVCRYPT_ENOMEM; + } + + msg->prikey = data; + msg->prikey_dma = virt_to_phys(msg->prikey); + if (!msg->prikey_dma) { + EMSG("Fail to get prikey dma addr"); + free(data); + return HISI_QM_DRVCRYPT_EFAULT; + } + + msg->in = data + (msg->key_bytes << 1) + (msg->key_bytes >> 1); + msg->in_dma = msg->prikey_dma + (msg->key_bytes << 1) + + (msg->key_bytes >> 1); + + msg->out = msg->in + msg->key_bytes; + msg->out_dma = msg->in_dma + msg->key_bytes; + + return HISI_QM_DRVCRYPT_NO_ERR; +} + +static enum hisi_drv_status hpre_rsa_ncrt_decrypt_alloc(struct hpre_rsa_msg *msg) +{ + uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes); + uint8_t *data = NULL; + + data = calloc(1, size); + if (!data) { + EMSG("Fail to alloc rsa ncrt buf"); + return HISI_QM_DRVCRYPT_ENOMEM; + } + + msg->prikey = data; + msg->prikey_dma = virt_to_phys(msg->prikey); + if (!msg->prikey_dma) { + EMSG("Fail to get prikey dma addr"); + free(data); + return HISI_QM_DRVCRYPT_EFAULT; + } + + msg->in = data + (msg->key_bytes << 1); + msg->in_dma = msg->prikey_dma + (msg->key_bytes << 1); + + msg->out = msg->in + msg->key_bytes; + msg->out_dma = msg->in_dma + msg->key_bytes; + + return HISI_QM_DRVCRYPT_NO_ERR; +} + +static enum hisi_drv_status +hpre_rsa_crt_decrypt_bn2bin(struct hpre_rsa_msg *msg, + struct drvcrypt_rsa_ed *rsa_data) +{ + struct rsa_keypair *key = rsa_data->key.key; + uint32_t p_bytes = msg->key_bytes >> 1; + uint32_t dq_len = crypto_bignum_num_bytes(key->dq); + uint32_t dp_len = crypto_bignum_num_bytes(key->dp); + uint32_t q_len = crypto_bignum_num_bytes(key->q); + uint32_t p_len = crypto_bignum_num_bytes(key->p); + uint32_t qp_len = crypto_bignum_num_bytes(key->qp); + uint8_t *dq = msg->prikey; + uint8_t *dp = msg->prikey + p_bytes; + uint8_t *q = dp + p_bytes; + uint8_t *p = q + p_bytes; + uint8_t *qp = p + p_bytes; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + + crypto_bignum_bn2bin(key->dq, dq); + crypto_bignum_bn2bin(key->dp, dp); + crypto_bignum_bn2bin(key->q, q); + crypto_bignum_bn2bin(key->p, p); + crypto_bignum_bn2bin(key->qp, qp); + + ret = hpre_bin_from_crypto_bin(dq, dq, p_bytes, dq_len); + if (ret) { + EMSG("Fail to transfer rsa crt dq from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(dp, dp, p_bytes, dp_len); + if (ret) { + EMSG("Fail to transfer rsa crt dp from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(q, q, p_bytes, q_len); + if (ret) { + EMSG("Fail to transfer rsa crt q from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(p, p, p_bytes, p_len); + if (ret) { + EMSG("Fail to transfer rsa crt p from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(qp, qp, p_bytes, qp_len); + if (ret) { + EMSG("Fail to transfer rsa crt qinv from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data, + msg->key_bytes, rsa_data->cipher.length); + if (ret) + EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin"); + + return ret; +} + +static enum hisi_drv_status +hpre_rsa_ncrt_decrypt_bn2bin(struct hpre_rsa_msg *msg, + struct drvcrypt_rsa_ed *rsa_data) +{ + struct rsa_keypair *key = rsa_data->key.key; + uint32_t d_len = 0; + uint32_t n_len = 0; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + uint8_t *n = NULL; + + n = msg->prikey + msg->key_bytes; + + crypto_bignum_bn2bin(key->d, msg->prikey); + crypto_bignum_bn2bin(key->n, n); + d_len = crypto_bignum_num_bytes(key->d); + n_len = crypto_bignum_num_bytes(key->n); + + ret = hpre_bin_from_crypto_bin(msg->prikey, msg->prikey, msg->key_bytes, + d_len); + if (ret) { + EMSG("Fail to transfer rsa ncrt d from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len); + if (ret) { + EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin"); + return ret; + } + + ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data, + msg->key_bytes, rsa_data->cipher.length); + if (ret) + EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin"); + + return ret; +} + +static int32_t hpre_rsa_is_crt_mod(struct rsa_keypair *key) +{ + int32_t is_crt = 0; + + if (key->p && crypto_bignum_num_bits(key->p) && + key->q && crypto_bignum_num_bits(key->q) && + key->dp && crypto_bignum_num_bits(key->dp) && + key->dq && crypto_bignum_num_bits(key->dq) && + key->qp && crypto_bignum_num_bits(key->qp)) + is_crt = 1; + + return is_crt; +} + +static TEE_Result hpre_rsa_decrypt_init(struct hpre_rsa_msg *msg, + struct drvcrypt_rsa_ed *rsa_data) +{ + struct rsa_keypair *key = rsa_data->key.key; + size_t n_bytes = rsa_data->key.n_size; + int32_t is_crt = 0; + enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; + + msg->is_private = rsa_data->key.isprivate; + msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes)); + if (!msg->key_bytes) + return TEE_ERROR_BAD_PARAMETERS; + + is_crt = hpre_rsa_is_crt_mod(key); + if (is_crt) { + msg->alg_type = HPRE_ALG_NC_CRT; + ret = hpre_rsa_crt_decrypt_alloc(msg); + if (ret) + return TEE_ERROR_BAD_STATE; + + ret = hpre_rsa_crt_decrypt_bn2bin(msg, rsa_data); + if (ret) { + hpre_rsa_params_free(msg); + return TEE_ERROR_BAD_STATE; + } + } else { + msg->alg_type = HPRE_ALG_NC_NCRT; + ret = hpre_rsa_ncrt_decrypt_alloc(msg); + if (ret) + return TEE_ERROR_BAD_STATE; + + ret = hpre_rsa_ncrt_decrypt_bn2bin(msg, rsa_data); + if (ret) { + hpre_rsa_params_free(msg); + return TEE_ERROR_BAD_STATE; + } + } + + return TEE_SUCCESS; +} + +static TEE_Result rsa_nopad_decrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + size_t n_bytes = rsa_data->key.n_size; + struct hpre_rsa_msg msg = { }; + uint32_t offset = 0; + TEE_Result ret = TEE_SUCCESS; + uint8_t *pos = NULL; + + if (rsa_data->cipher.length > n_bytes) { + EMSG("Invalid cipher length[%zu]", rsa_data->cipher.length); + return TEE_ERROR_BAD_PARAMETERS; + } + + ret = hpre_rsa_decrypt_init(&msg, rsa_data); + if (ret) { + EMSG("Fail to init rsa msg"); + return ret; + } + + ret = hpre_rsa_do_task(&msg); + if (ret) + goto decrypt_deinit; + + pos = msg.out + msg.key_bytes - n_bytes; + if (rsa_data->rsa_id == DRVCRYPT_RSA_NOPAD) { + /* Plaintext can not have valid zero data in NOPAD MODE */ + while ((offset < n_bytes - 1) && (pos[offset] == 0)) + offset++; + } + + rsa_data->message.length = n_bytes - offset; + memcpy(rsa_data->message.data, pos + offset, rsa_data->message.length); + +decrypt_deinit: + hpre_rsa_params_free(&msg); + + return ret; +} + +static TEE_Result rsaes_pkcs_v1_5_decode(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *out, size_t *out_len) +{ + size_t em_len = rsa_data->message.length; + uint8_t *em = rsa_data->message.data; + size_t ps_len = 0; + size_t i = 0; + + /* PKCS_V1.5 EM format 0x00 || 0x02 || PS non-zero || 0x00 || M */ + if (em_len < PKCS_V1_5_MSG_MIN_LEN || em[0] != 0 || + em[1] != ENCRYPT_PAD) { + EMSG("Invalid pkcs_v1.5 decode parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + for (i = PKCS_V1_5_PS_POS; i < em_len; i++) { + if (em[i] == 0) + break; + } + + if (i >= em_len) { + EMSG("Fail to find zero pos"); + return TEE_ERROR_BAD_PARAMETERS; + } + + ps_len = i - PKCS_V1_5_PS_POS; + if (em_len - ps_len - PKCS_V1_5_FIXED_LEN > *out_len || + ps_len < PKCS_V1_5_PS_MIN_LEN) { + EMSG("Invalid pkcs_v1.5 decode ps_len"); + return TEE_ERROR_BAD_PARAMETERS; + } + + *out_len = em_len - ps_len - PKCS_V1_5_FIXED_LEN; + memcpy(out, em + ps_len + PKCS_V1_5_FIXED_LEN, *out_len); + + return TEE_SUCCESS; +} + +static TEE_Result rsa_pkcs_decrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + uint32_t n_bytes = rsa_data->key.n_size; + struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data; + TEE_Result ret = TEE_SUCCESS; + + /* Alloc pkcs_v1.5 encode message data buf */ + rsa_dec_info.message.data = malloc(n_bytes); + if (!rsa_dec_info.message.data) { + EMSG("Fail to alloc message data buf"); + return TEE_ERROR_OUT_OF_MEMORY; + } + + rsa_dec_info.message.length = n_bytes; + ret = rsa_nopad_decrypt(&rsa_dec_info); + if (ret) + goto free_data; + + ret = rsaes_pkcs_v1_5_decode(&rsa_dec_info, rsa_data->message.data, + &rsa_data->message.length); + if (ret) + EMSG("Fail to get pkcs_v1.5 decode message data"); + +free_data: + free(rsa_dec_info.message.data); + + return ret; +} + +static TEE_Result rsa_oaep_get_seed(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *mask_db, uint8_t *seed) +{ + size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1; + uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { }; + size_t lhash_len = rsa_data->digest_size; + uint8_t *mask_seed = NULL; + TEE_Result ret = TEE_SUCCESS; + + mask_seed = rsa_data->message.data + 1; + + ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len, + rsa_data); + if (ret) { + EMSG("Fail to get mask_db mgf result"); + return ret; + } + + return xor_process(mask_seed, mask_db_mgf, seed, lhash_len); +} + +static TEE_Result rsa_oaep_get_db(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *mask_db, uint8_t *seed, uint8_t *db) +{ + size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1; + size_t lhash_len = rsa_data->digest_size; + uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { }; + TEE_Result ret = TEE_SUCCESS; + + ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len, + rsa_data); + if (ret) { + EMSG("Fail to get seed mgf result"); + return ret; + } + + return xor_process(mask_db, seed_mgf, db, db_len); +} + +static TEE_Result rsa_oaep_get_msg(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *db, uint8_t *out, size_t *out_len) +{ + size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1; + size_t lhash_len = rsa_data->digest_size; + uint8_t hash[OAEP_MAX_HASH_LEN] = { }; + size_t msg_len = 0; + size_t lp_len = 0; + TEE_Result ret = TEE_SUCCESS; + + /* oaep db format lhash || ps zero || 01 || M */ + ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data, + rsa_data->label.length, hash, lhash_len); + if (ret) { + EMSG("Fail to get label hash"); + return ret; + } + + if (memcmp(hash, db, lhash_len)) { + EMSG("Hash is not equal"); + return TEE_ERROR_BAD_PARAMETERS; + } + + for (lp_len = lhash_len; lp_len < db_len; lp_len++) { + if (db[lp_len] != 0) + break; + } + + if (lp_len == db_len) { + EMSG("Fail to find fixed 01 in db"); + return TEE_ERROR_BAD_PARAMETERS; + } + + msg_len = db_len - lp_len - 1; + if (msg_len > rsa_data->message.length) { + EMSG("Message space is not enough"); + return TEE_ERROR_SHORT_BUFFER; + } + + *out_len = msg_len; + memcpy(out, db + lp_len + 1, msg_len); + + return TEE_SUCCESS; +} + +static TEE_Result rsa_oaep_decode(struct drvcrypt_rsa_ed *rsa_data, + uint8_t *out, size_t *out_len) +{ + size_t lhash_len = rsa_data->digest_size; + uint8_t seed[OAEP_MAX_HASH_LEN] = { }; + uint8_t db[OAEP_MAX_DB_LEN] = { }; + uint8_t *mask_db = NULL; + TEE_Result ret = TEE_SUCCESS; + + /* oaep format 00 || maskedseed || maskeddb */ + mask_db = rsa_data->message.data + lhash_len + 1; + ret = rsa_oaep_get_seed(rsa_data, mask_db, seed); + if (ret) + return ret; + + ret = rsa_oaep_get_db(rsa_data, mask_db, seed, db); + if (ret) + return ret; + + return rsa_oaep_get_msg(rsa_data, db, out, out_len); +} + +static TEE_Result rsa_oaep_decrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + size_t n_bytes = rsa_data->key.n_size; + struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data; + TEE_Result ret = TEE_SUCCESS; + + /* Alloc oaep encode message data buf */ + rsa_dec_info.message.data = malloc(n_bytes); + if (!rsa_dec_info.message.data) { + EMSG("Fail to alloc message data buf"); + return TEE_ERROR_OUT_OF_MEMORY; + } + + rsa_dec_info.message.length = n_bytes; + ret = rsa_nopad_decrypt(&rsa_dec_info); + if (ret) + goto free_data; + + ret = rsa_oaep_decode(&rsa_dec_info, rsa_data->message.data, + &rsa_data->message.length); + if (ret) + EMSG("Fail to get oaep decode message data"); + +free_data: + free(rsa_dec_info.message.data); + + return ret; +} + +static TEE_Result hpre_rsa_decrypt(struct drvcrypt_rsa_ed *rsa_data) +{ + if (!rsa_data) { + EMSG("Invalid rsa decrypt input parameter"); + return TEE_ERROR_BAD_PARAMETERS; + } + + switch (rsa_data->rsa_id) { + case DRVCRYPT_RSA_NOPAD: + case DRVCRYPT_RSASSA_PKCS_V1_5: + case DRVCRYPT_RSASSA_PSS: + return rsa_nopad_decrypt(rsa_data); + case DRVCRYPT_RSA_PKCS_V1_5: + return rsa_pkcs_decrypt(rsa_data); + case DRVCRYPT_RSA_OAEP: + return rsa_oaep_decrypt(rsa_data); + default: + EMSG("Invalid rsa id"); + return TEE_ERROR_NOT_SUPPORTED; + } +} + +static const struct drvcrypt_rsa driver_rsa = { + .alloc_keypair = sw_crypto_acipher_alloc_rsa_keypair, + .alloc_publickey = hpre_rsa_allocate_publickey, + .free_publickey = hpre_rsa_free_publickey, + .free_keypair = sw_crypto_acipher_free_rsa_keypair, + .gen_keypair = sw_crypto_acipher_gen_rsa_key, + .encrypt = hpre_rsa_encrypt, + .decrypt = hpre_rsa_decrypt, + .optional = { + /* + * If ssa_sign or verify is NULL, the framework will fill + * data format directly by soft calculation. Then call api + * encrypt or decrypt. + */ + .ssa_sign = NULL, + .ssa_verify = NULL, + }, +}; + +TEE_Result hpre_rsa_init(void) +{ + TEE_Result ret = drvcrypt_register_rsa(&driver_rsa); + + if (ret != TEE_SUCCESS) + EMSG("hpre rsa register to crypto fail"); + + return ret; +} + +driver_init(hpre_rsa_init); diff --git a/core/drivers/crypto/hisilicon/hpre_rsa.h b/core/drivers/crypto/hisilicon/hpre_rsa.h new file mode 100644 index 00000000000..5588c913fe5 --- /dev/null +++ b/core/drivers/crypto/hisilicon/hpre_rsa.h @@ -0,0 +1,47 @@ +/* SPDX-License-Identifier: BSD-2-Clause */ +/* + * Copyright (c) 2024, HiSilicon Technologies Co., Ltd. + */ + +#ifndef __TEE_HPRE_RSA_H__ +#define __TEE_HPRE_RSA_H__ + +#include +#include +#include + +#define PKCS_V1_5_MSG_MIN_LEN 11 +#define PKCS_V1_5_PS_MIN_LEN 8 +#define PKCS_V1_5_PS_POS 2 +#define PKCS_V1_5_FIXED_LEN 3 +#define PKCS_V1_5_PS_DATA 0x5a +#define OAEP_MAX_HASH_LEN 64 +#define OAEP_MAX_DB_LEN 512 +#define PRIME_QUALITY_FLAG 1024 +#define HPRE_RSA_GEN_TOTAL_BUF_SIZE(kbytes) ((kbytes) * 7) +#define HPRE_RSA_CRT_TOTAL_BUF_SIZE(kbytes) ((kbytes) * 6) +#define HPRE_RSA_CRT_KEY_BUF_SIZE(kbytes) ((kbytes) >> 10) +#define HPRE_RSA_NCRT_TOTAL_BUF_SIZE(kbytes) ((kbytes) * 4) + +enum pkcs_v1_5_pad_type { + SIGN_PAD = 1, + ENCRYPT_PAD = 2 +}; + +struct hpre_rsa_msg { + uint8_t *pubkey; + paddr_t pubkey_dma; + uint8_t *prikey; + paddr_t prikey_dma; + uint8_t *in; + paddr_t in_dma; + uint8_t *out; + paddr_t out_dma; + uint32_t alg_type; + uint32_t key_bytes; + bool is_private; /* True if private key */ +}; + +TEE_Result hpre_rsa_init(void); + +#endif diff --git a/core/drivers/crypto/hisilicon/sub.mk b/core/drivers/crypto/hisilicon/sub.mk index cf2439645e1..de9d7e9e290 100644 --- a/core/drivers/crypto/hisilicon/sub.mk +++ b/core/drivers/crypto/hisilicon/sub.mk @@ -8,3 +8,4 @@ srcs-$(CFG_HISILICON_ACC_V3) += hpre_main.c srcs-$(CFG_HISILICON_ACC_V3) += hpre_dh.c srcs-$(CFG_HISILICON_ACC_V3) += hpre_ecc.c srcs-$(CFG_HISILICON_ACC_V3) += hpre_montgomery.c +srcs-$(CFG_HISILICON_ACC_V3) += hpre_rsa.c \ No newline at end of file