Skip to content

Commit be84f86

Browse files
committed
Fix Dilithium issue following WolfCrypt update
Initial fix Updates Buffer size fixes
1 parent 01fe072 commit be84f86

9 files changed

Lines changed: 488 additions & 95 deletions

src/wh_client_crypto.c

Lines changed: 94 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5474,7 +5474,7 @@ int wh_Client_MlDsaImportKey(whClientContext* ctx, MlDsaKey* key,
54745474
{
54755475
int ret = WH_ERROR_OK;
54765476
whKeyId key_id = WH_KEYID_ERASED;
5477-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
5477+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
54785478
uint16_t buffer_len = 0;
54795479

54805480
if ((ctx == NULL) || (key == NULL) ||
@@ -5511,7 +5511,7 @@ int wh_Client_MlDsaExportKey(whClientContext* ctx, whKeyId keyId, MlDsaKey* key,
55115511
{
55125512
int ret = WH_ERROR_OK;
55135513
/* buffer cannot be larger than MTU */
5514-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
5514+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
55155515
uint16_t buffer_len = sizeof(buffer);
55165516

55175517
if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) {
@@ -5665,8 +5665,10 @@ int wh_Client_MlDsaMakeExportKey(whClientContext* ctx, int level, int size,
56655665
}
56665666

56675667

5668-
int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
5669-
byte* out, word32* inout_len, MlDsaKey* key)
5668+
int wh_Client_MlDsaSign_ex(whClientContext* ctx, const byte* in, word32 in_len,
5669+
byte* out, word32* inout_len, MlDsaKey* key,
5670+
const byte* context, byte contextLen,
5671+
word32 preHashType)
56705672
{
56715673
int ret = 0;
56725674
whMessageCrypto_MlDsaSignRequest* req = NULL;
@@ -5709,7 +5711,7 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57095711
uint16_t action = WC_ALGO_TYPE_PK;
57105712

57115713
uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) +
5712-
sizeof(*req) + in_len;
5714+
sizeof(*req) + in_len + contextLen;
57135715
uint32_t options = 0;
57145716

57155717
/* Get data pointer from the context to use as request/response storage
@@ -5726,18 +5728,23 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57265728
ctx->cryptoAffinity);
57275729

57285730
if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) {
5729-
uint8_t* req_hash = (uint8_t*)(req + 1);
5731+
uint8_t* req_data = (uint8_t*)(req + 1);
57305732
if (evict != 0) {
57315733
options |= WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT;
57325734
}
57335735

57345736
memset(req, 0, sizeof(*req));
5735-
req->options = options;
5736-
req->level = key->level;
5737-
req->keyId = key_id;
5738-
req->sz = in_len;
5737+
req->options = options;
5738+
req->level = key->level;
5739+
req->keyId = key_id;
5740+
req->sz = in_len;
5741+
req->contextSz = contextLen;
5742+
req->preHashType = preHashType;
57395743
if ((in != NULL) && (in_len > 0)) {
5740-
memcpy(req_hash, in, in_len);
5744+
memcpy(req_data, in, in_len);
5745+
}
5746+
if ((context != NULL) && (contextLen > 0)) {
5747+
memcpy(req_data + in_len, context, contextLen);
57415748
}
57425749

57435750
/* Send Request */
@@ -5793,9 +5800,17 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57935800
return ret;
57945801
}
57955802

5796-
int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
5797-
const byte* msg, word32 msg_len, int* out_res,
5798-
MlDsaKey* key)
5803+
int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
5804+
byte* out, word32* out_len, MlDsaKey* key)
5805+
{
5806+
return wh_Client_MlDsaSign_ex(ctx, in, in_len, out, out_len, key,
5807+
NULL, 0, WC_HASH_TYPE_NONE);
5808+
}
5809+
5810+
int wh_Client_MlDsaVerify_ex(whClientContext* ctx, const byte* sig, word32 sig_len,
5811+
const byte* msg, word32 msg_len, int* out_res,
5812+
MlDsaKey* key, const byte* context, byte contextLen,
5813+
word32 preHashType)
57995814
{
58005815
int ret = WH_ERROR_OK;
58015816
uint8_t* dataPtr = NULL;
@@ -5838,7 +5853,7 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
58385853
uint32_t options = 0;
58395854

58405855
uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) +
5841-
sizeof(*req) + sig_len + msg_len;
5856+
sizeof(*req) + sig_len + msg_len + contextLen;
58425857

58435858

58445859
/* Get data pointer from the context to use as request/response storage
@@ -5864,17 +5879,22 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
58645879
}
58655880

58665881
memset(req, 0, sizeof(*req));
5867-
req->options = options;
5868-
req->level = key->level;
5869-
req->keyId = key_id;
5870-
req->sigSz = sig_len;
5882+
req->options = options;
5883+
req->level = key->level;
5884+
req->keyId = key_id;
5885+
req->sigSz = sig_len;
58715886
if ((sig != NULL) && (sig_len > 0)) {
58725887
memcpy(req_sig, sig, sig_len);
58735888
}
5874-
req->hashSz = msg_len;
5889+
req->hashSz = msg_len;
58755890
if ((msg != NULL) && (msg_len > 0)) {
58765891
memcpy(req_hash, msg, msg_len);
58775892
}
5893+
req->contextSz = contextLen;
5894+
req->preHashType = preHashType;
5895+
if ((context != NULL) && (contextLen > 0)) {
5896+
memcpy(req_hash + msg_len, context, contextLen);
5897+
}
58785898

58795899
/* write request */
58805900
ret = wh_Client_SendRequest(ctx, group, action, req_len,
@@ -5917,6 +5937,14 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
59175937
return ret;
59185938
}
59195939

5940+
int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
5941+
const byte* msg, word32 msg_len, int* res,
5942+
MlDsaKey* key)
5943+
{
5944+
return wh_Client_MlDsaVerify_ex(ctx, sig, sig_len, msg, msg_len, res, key,
5945+
NULL, 0, WC_HASH_TYPE_NONE);
5946+
}
5947+
59205948
int wh_Client_MlDsaCheckPrivKey(whClientContext* ctx, MlDsaKey* key,
59215949
const byte* pubKey, word32 pubKeySz)
59225950
{
@@ -5937,7 +5965,7 @@ int wh_Client_MlDsaImportKeyDma(whClientContext* ctx, MlDsaKey* key,
59375965
{
59385966
int ret = WH_ERROR_OK;
59395967
whKeyId key_id = WH_KEYID_ERASED;
5940-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
5968+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
59415969
uint16_t buffer_len = 0;
59425970

59435971
if ((ctx == NULL) || (key == NULL) ||
@@ -5969,7 +5997,7 @@ int wh_Client_MlDsaExportKeyDma(whClientContext* ctx, whKeyId keyId,
59695997
uint8_t* label)
59705998
{
59715999
int ret = WH_ERROR_OK;
5972-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE] = {0};
6000+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE] = {0};
59736001
uint16_t buffer_len = sizeof(buffer);
59746002

59756003
if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) {
@@ -5993,7 +6021,7 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level,
59936021
{
59946022
int ret = WH_ERROR_OK;
59956023
whKeyId key_id = WH_KEYID_ERASED;
5996-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
6024+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
59976025
uint8_t* dataPtr = NULL;
59986026
whMessageCrypto_MlDsaKeyGenDmaRequest* req = NULL;
59996027
whMessageCrypto_MlDsaKeyGenDmaResponse* res = NULL;
@@ -6115,8 +6143,10 @@ int wh_Client_MlDsaMakeExportKeyDma(whClientContext* ctx, int level,
61156143
}
61166144

61176145

6118-
int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
6119-
byte* out, word32* out_len, MlDsaKey* key)
6146+
int wh_Client_MlDsaSignDma_ex(whClientContext* ctx, const byte* in, word32 in_len,
6147+
byte* out, word32* out_len, MlDsaKey* key,
6148+
const byte* context, byte contextLen,
6149+
word32 preHashType)
61206150
{
61216151
int ret = 0;
61226152
whMessageCrypto_MlDsaSignDmaRequest* req = NULL;
@@ -6158,7 +6188,8 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
61586188
uint16_t action = WC_ALGO_TYPE_PK;
61596189

61606190
uint16_t req_len =
6161-
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req);
6191+
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) +
6192+
contextLen;
61626193
uint32_t options = 0;
61636194

61646195
/* Get data pointer from the context to use as request/response storage
@@ -6180,9 +6211,14 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
61806211
}
61816212

61826213
memset(req, 0, sizeof(*req));
6183-
req->options = options;
6184-
req->level = key->level;
6185-
req->keyId = key_id;
6214+
req->options = options;
6215+
req->level = key->level;
6216+
req->keyId = key_id;
6217+
req->contextSz = contextLen;
6218+
req->preHashType = preHashType;
6219+
if ((context != NULL) && (contextLen > 0)) {
6220+
memcpy((uint8_t*)(req + 1), context, contextLen);
6221+
}
61866222

61876223
/* Set up DMA buffers */
61886224
req->msg.sz = in_len;
@@ -6254,9 +6290,17 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
62546290
return ret;
62556291
}
62566292

6257-
int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
6258-
word32 sig_len, const byte* msg, word32 msg_len,
6259-
int* out_res, MlDsaKey* key)
6293+
int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
6294+
byte* out, word32* out_len, MlDsaKey* key)
6295+
{
6296+
return wh_Client_MlDsaSignDma_ex(ctx, in, in_len, out, out_len, key,
6297+
NULL, 0, WC_HASH_TYPE_NONE);
6298+
}
6299+
6300+
int wh_Client_MlDsaVerifyDma_ex(whClientContext* ctx, const byte* sig,
6301+
word32 sig_len, const byte* msg, word32 msg_len,
6302+
int* out_res, MlDsaKey* key, const byte* context,
6303+
byte contextLen, word32 preHashType)
62606304
{
62616305
int ret = 0;
62626306
whMessageCrypto_MlDsaVerifyDmaRequest* req = NULL;
@@ -6296,7 +6340,8 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
62966340
uintptr_t msgAddr = 0;
62976341

62986342
uint16_t req_len =
6299-
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req);
6343+
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) +
6344+
contextLen;
63006345

63016346
/* Get data pointer from the context to use as request/response storage
63026347
*/
@@ -6317,9 +6362,14 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
63176362
}
63186363

63196364
memset(req, 0, sizeof(*req));
6320-
req->options = options;
6321-
req->level = key->level;
6322-
req->keyId = key_id;
6365+
req->options = options;
6366+
req->level = key->level;
6367+
req->keyId = key_id;
6368+
req->contextSz = contextLen;
6369+
req->preHashType = preHashType;
6370+
if ((context != NULL) && (contextLen > 0)) {
6371+
memcpy((uint8_t*)(req + 1), context, contextLen);
6372+
}
63236373

63246374
/* Set up DMA buffers */
63256375
req->sig.sz = sig_len;
@@ -6391,6 +6441,14 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
63916441
return ret;
63926442
}
63936443

6444+
int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
6445+
word32 sig_len, const byte* msg, word32 msg_len,
6446+
int* res, MlDsaKey* key)
6447+
{
6448+
return wh_Client_MlDsaVerifyDma_ex(ctx, sig, sig_len, msg, msg_len, res, key,
6449+
NULL, 0, WC_HASH_TYPE_NONE);
6450+
}
6451+
63946452

63956453
int wh_Client_MlDsaCheckPrivKeyDma(whClientContext* ctx, MlDsaKey* key,
63966454
const byte* pubKey, word32 pubKeySz)

src/wh_client_cryptocb.c

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,29 @@ static int _handlePqcSigCheckPrivKey(whClientContext* ctx, wc_CryptoInfo* info,
6565
int useDma);
6666
#endif /* HAVE_DILITHIUM || HAVE_FALCON */
6767

68+
/* Internal extended ML-DSA helpers (not part of public API) */
69+
#ifdef HAVE_DILITHIUM
70+
int wh_Client_MlDsaSign_ex(whClientContext* ctx, const byte* in, word32 in_len,
71+
byte* out, word32* out_len, MlDsaKey* key,
72+
const byte* context, byte contextLen,
73+
word32 preHashType);
74+
int wh_Client_MlDsaVerify_ex(whClientContext* ctx, const byte* sig,
75+
word32 sig_len, const byte* msg, word32 msg_len,
76+
int* res, MlDsaKey* key, const byte* context,
77+
byte contextLen, word32 preHashType);
78+
#ifdef WOLFHSM_CFG_DMA
79+
int wh_Client_MlDsaSignDma_ex(whClientContext* ctx, const byte* in,
80+
word32 in_len, byte* out, word32* out_len,
81+
MlDsaKey* key, const byte* context,
82+
byte contextLen, word32 preHashType);
83+
int wh_Client_MlDsaVerifyDma_ex(whClientContext* ctx, const byte* sig,
84+
word32 sig_len, const byte* msg,
85+
word32 msg_len, int* res, MlDsaKey* key,
86+
const byte* context, byte contextLen,
87+
word32 preHashType);
88+
#endif /* WOLFHSM_CFG_DMA */
89+
#endif /* HAVE_DILITHIUM */
90+
6891
int wh_Client_CryptoCb(int devId, wc_CryptoInfo* info, void* inCtx)
6992
{
7093
/* III When possible, return wolfCrypt-enumerated errors */
@@ -642,12 +665,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma)
642665
int ret = CRYPTOCB_UNAVAILABLE;
643666

644667
/* Extract info parameters */
645-
const byte* in = info->pk.pqc_sign.in;
646-
word32 in_len = info->pk.pqc_sign.inlen;
647-
byte* out = info->pk.pqc_sign.out;
648-
word32* out_len = info->pk.pqc_sign.outlen;
649-
void* key = info->pk.pqc_sign.key;
650-
int type = info->pk.pqc_sign.type;
668+
const byte* in = info->pk.pqc_sign.in;
669+
word32 in_len = info->pk.pqc_sign.inlen;
670+
byte* out = info->pk.pqc_sign.out;
671+
word32* out_len = info->pk.pqc_sign.outlen;
672+
void* key = info->pk.pqc_sign.key;
673+
int type = info->pk.pqc_sign.type;
674+
const byte* context = info->pk.pqc_sign.context;
675+
byte contextLen = info->pk.pqc_sign.contextLen;
676+
word32 preHashType = info->pk.pqc_sign.preHashType;
651677

652678
#ifndef WOLFHSM_CFG_DMA
653679
if (useDma) {
@@ -661,13 +687,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma)
661687
case WC_PQC_SIG_TYPE_DILITHIUM:
662688
#ifdef WOLFHSM_CFG_DMA
663689
if (useDma) {
664-
ret =
665-
wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len, key);
690+
ret = wh_Client_MlDsaSignDma_ex(ctx, in, in_len, out, out_len,
691+
key, context, contextLen,
692+
preHashType);
666693
}
667694
else
668695
#endif /* WOLFHSM_CFG_DMA */
669696
{
670-
ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key);
697+
ret = wh_Client_MlDsaSign_ex(ctx, in, in_len, out, out_len, key,
698+
context, contextLen, preHashType);
671699
}
672700
break;
673701
#endif /* HAVE_DILITHIUM */
@@ -688,13 +716,16 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info,
688716
int ret = CRYPTOCB_UNAVAILABLE;
689717

690718
/* Extract info parameters */
691-
const byte* sig = info->pk.pqc_verify.sig;
692-
word32 sig_len = info->pk.pqc_verify.siglen;
693-
const byte* msg = info->pk.pqc_verify.msg;
694-
word32 msg_len = info->pk.pqc_verify.msglen;
695-
int* res = info->pk.pqc_verify.res;
696-
void* key = info->pk.pqc_verify.key;
697-
int type = info->pk.pqc_verify.type;
719+
const byte* sig = info->pk.pqc_verify.sig;
720+
word32 sig_len = info->pk.pqc_verify.siglen;
721+
const byte* msg = info->pk.pqc_verify.msg;
722+
word32 msg_len = info->pk.pqc_verify.msglen;
723+
int* res = info->pk.pqc_verify.res;
724+
void* key = info->pk.pqc_verify.key;
725+
int type = info->pk.pqc_verify.type;
726+
const byte* context = info->pk.pqc_verify.context;
727+
byte contextLen = info->pk.pqc_verify.contextLen;
728+
word32 preHashType = info->pk.pqc_verify.preHashType;
698729

699730
#ifndef WOLFHSM_CFG_DMA
700731
if (useDma) {
@@ -708,14 +739,16 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info,
708739
case WC_PQC_SIG_TYPE_DILITHIUM:
709740
#ifdef WOLFHSM_CFG_DMA
710741
if (useDma) {
711-
ret = wh_Client_MlDsaVerifyDma(ctx, sig, sig_len, msg, msg_len,
712-
res, key);
742+
ret = wh_Client_MlDsaVerifyDma_ex(ctx, sig, sig_len, msg, msg_len,
743+
res, key, context, contextLen,
744+
preHashType);
713745
}
714746
else
715747
#endif /* WOLFHSM_CFG_DMA */
716748
{
717-
ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len, res,
718-
key);
749+
ret = wh_Client_MlDsaVerify_ex(ctx, sig, sig_len, msg, msg_len,
750+
res, key, context, contextLen,
751+
preHashType);
719752
}
720753
break;
721754
#endif /* HAVE_DILITHIUM */

0 commit comments

Comments
 (0)