Skip to content

Commit bdb44ae

Browse files
authored
Merge pull request #301 from padelsbach/padelsbach/dilithium-fix
Fix Dilithium issue following WolfCrypt update
2 parents 01fe072 + 3546862 commit bdb44ae

File tree

10 files changed

+451
-102
lines changed

10 files changed

+451
-102
lines changed

src/wh_client_crypto.c

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5474,7 +5474,7 @@ int wh_Client_MlDsaImportKey(whClientContext* ctx, MlDsaKey* key,
54745474
{
54755475
int ret = WH_ERROR_OK;
54765476
whKeyId key_id = WH_KEYID_ERASED;
5477-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
5477+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
54785478
uint16_t buffer_len = 0;
54795479

54805480
if ((ctx == NULL) || (key == NULL) ||
@@ -5511,7 +5511,7 @@ int wh_Client_MlDsaExportKey(whClientContext* ctx, whKeyId keyId, MlDsaKey* key,
55115511
{
55125512
int ret = WH_ERROR_OK;
55135513
/* buffer cannot be larger than MTU */
5514-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
5514+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
55155515
uint16_t buffer_len = sizeof(buffer);
55165516

55175517
if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) {
@@ -5666,7 +5666,9 @@ int wh_Client_MlDsaMakeExportKey(whClientContext* ctx, int level, int size,
56665666

56675667

56685668
int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
5669-
byte* out, word32* inout_len, MlDsaKey* key)
5669+
byte* out, word32* inout_len, MlDsaKey* key,
5670+
const byte* context, byte contextLen,
5671+
word32 preHashType)
56705672
{
56715673
int ret = 0;
56725674
whMessageCrypto_MlDsaSignRequest* req = NULL;
@@ -5709,7 +5711,7 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57095711
uint16_t action = WC_ALGO_TYPE_PK;
57105712

57115713
uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) +
5712-
sizeof(*req) + in_len;
5714+
sizeof(*req) + in_len + contextLen;
57135715
uint32_t options = 0;
57145716

57155717
/* Get data pointer from the context to use as request/response storage
@@ -5726,18 +5728,23 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57265728
ctx->cryptoAffinity);
57275729

57285730
if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) {
5729-
uint8_t* req_hash = (uint8_t*)(req + 1);
5731+
uint8_t* req_data = (uint8_t*)(req + 1);
57305732
if (evict != 0) {
57315733
options |= WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT;
57325734
}
57335735

57345736
memset(req, 0, sizeof(*req));
5735-
req->options = options;
5736-
req->level = key->level;
5737-
req->keyId = key_id;
5738-
req->sz = in_len;
5737+
req->options = options;
5738+
req->level = key->level;
5739+
req->keyId = key_id;
5740+
req->sz = in_len;
5741+
req->contextSz = contextLen;
5742+
req->preHashType = preHashType;
57395743
if ((in != NULL) && (in_len > 0)) {
5740-
memcpy(req_hash, in, in_len);
5744+
memcpy(req_data, in, in_len);
5745+
}
5746+
if ((context != NULL) && (contextLen > 0)) {
5747+
memcpy(req_data + in_len, context, contextLen);
57415748
}
57425749

57435750
/* Send Request */
@@ -5794,8 +5801,9 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57945801
}
57955802

57965803
int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
5797-
const byte* msg, word32 msg_len, int* out_res,
5798-
MlDsaKey* key)
5804+
const byte* msg, word32 msg_len, int* out_res,
5805+
MlDsaKey* key, const byte* context, byte contextLen,
5806+
word32 preHashType)
57995807
{
58005808
int ret = WH_ERROR_OK;
58015809
uint8_t* dataPtr = NULL;
@@ -5838,7 +5846,7 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
58385846
uint32_t options = 0;
58395847

58405848
uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) +
5841-
sizeof(*req) + sig_len + msg_len;
5849+
sizeof(*req) + sig_len + msg_len + contextLen;
58425850

58435851

58445852
/* Get data pointer from the context to use as request/response storage
@@ -5864,17 +5872,22 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
58645872
}
58655873

58665874
memset(req, 0, sizeof(*req));
5867-
req->options = options;
5868-
req->level = key->level;
5869-
req->keyId = key_id;
5870-
req->sigSz = sig_len;
5875+
req->options = options;
5876+
req->level = key->level;
5877+
req->keyId = key_id;
5878+
req->sigSz = sig_len;
58715879
if ((sig != NULL) && (sig_len > 0)) {
58725880
memcpy(req_sig, sig, sig_len);
58735881
}
5874-
req->hashSz = msg_len;
5882+
req->hashSz = msg_len;
58755883
if ((msg != NULL) && (msg_len > 0)) {
58765884
memcpy(req_hash, msg, msg_len);
58775885
}
5886+
req->contextSz = contextLen;
5887+
req->preHashType = preHashType;
5888+
if ((context != NULL) && (contextLen > 0)) {
5889+
memcpy(req_hash + msg_len, context, contextLen);
5890+
}
58785891

58795892
/* write request */
58805893
ret = wh_Client_SendRequest(ctx, group, action, req_len,
@@ -5937,7 +5950,7 @@ int wh_Client_MlDsaImportKeyDma(whClientContext* ctx, MlDsaKey* key,
59375950
{
59385951
int ret = WH_ERROR_OK;
59395952
whKeyId key_id = WH_KEYID_ERASED;
5940-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
5953+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
59415954
uint16_t buffer_len = 0;
59425955

59435956
if ((ctx == NULL) || (key == NULL) ||
@@ -5969,7 +5982,7 @@ int wh_Client_MlDsaExportKeyDma(whClientContext* ctx, whKeyId keyId,
59695982
uint8_t* label)
59705983
{
59715984
int ret = WH_ERROR_OK;
5972-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE] = {0};
5985+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE] = {0};
59735986
uint16_t buffer_len = sizeof(buffer);
59745987

59755988
if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) {
@@ -5993,7 +6006,7 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level,
59936006
{
59946007
int ret = WH_ERROR_OK;
59956008
whKeyId key_id = WH_KEYID_ERASED;
5996-
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
6009+
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
59976010
uint8_t* dataPtr = NULL;
59986011
whMessageCrypto_MlDsaKeyGenDmaRequest* req = NULL;
59996012
whMessageCrypto_MlDsaKeyGenDmaResponse* res = NULL;
@@ -6116,7 +6129,9 @@ int wh_Client_MlDsaMakeExportKeyDma(whClientContext* ctx, int level,
61166129

61176130

61186131
int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
6119-
byte* out, word32* out_len, MlDsaKey* key)
6132+
byte* out, word32* out_len, MlDsaKey* key,
6133+
const byte* context, byte contextLen,
6134+
word32 preHashType)
61206135
{
61216136
int ret = 0;
61226137
whMessageCrypto_MlDsaSignDmaRequest* req = NULL;
@@ -6158,7 +6173,8 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
61586173
uint16_t action = WC_ALGO_TYPE_PK;
61596174

61606175
uint16_t req_len =
6161-
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req);
6176+
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) +
6177+
contextLen;
61626178
uint32_t options = 0;
61636179

61646180
/* Get data pointer from the context to use as request/response storage
@@ -6180,9 +6196,14 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
61806196
}
61816197

61826198
memset(req, 0, sizeof(*req));
6183-
req->options = options;
6184-
req->level = key->level;
6185-
req->keyId = key_id;
6199+
req->options = options;
6200+
req->level = key->level;
6201+
req->keyId = key_id;
6202+
req->contextSz = contextLen;
6203+
req->preHashType = preHashType;
6204+
if ((context != NULL) && (contextLen > 0)) {
6205+
memcpy((uint8_t*)(req + 1), context, contextLen);
6206+
}
61866207

61876208
/* Set up DMA buffers */
61886209
req->msg.sz = in_len;
@@ -6255,8 +6276,9 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
62556276
}
62566277

62576278
int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
6258-
word32 sig_len, const byte* msg, word32 msg_len,
6259-
int* out_res, MlDsaKey* key)
6279+
word32 sig_len, const byte* msg, word32 msg_len,
6280+
int* out_res, MlDsaKey* key, const byte* context,
6281+
byte contextLen, word32 preHashType)
62606282
{
62616283
int ret = 0;
62626284
whMessageCrypto_MlDsaVerifyDmaRequest* req = NULL;
@@ -6296,7 +6318,8 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
62966318
uintptr_t msgAddr = 0;
62976319

62986320
uint16_t req_len =
6299-
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req);
6321+
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) +
6322+
contextLen;
63006323

63016324
/* Get data pointer from the context to use as request/response storage
63026325
*/
@@ -6317,9 +6340,14 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
63176340
}
63186341

63196342
memset(req, 0, sizeof(*req));
6320-
req->options = options;
6321-
req->level = key->level;
6322-
req->keyId = key_id;
6343+
req->options = options;
6344+
req->level = key->level;
6345+
req->keyId = key_id;
6346+
req->contextSz = contextLen;
6347+
req->preHashType = preHashType;
6348+
if ((context != NULL) && (contextLen > 0)) {
6349+
memcpy((uint8_t*)(req + 1), context, contextLen);
6350+
}
63236351

63246352
/* Set up DMA buffers */
63256353
req->sig.sz = sig_len;

src/wh_client_cryptocb.c

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -642,12 +642,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma)
642642
int ret = CRYPTOCB_UNAVAILABLE;
643643

644644
/* Extract info parameters */
645-
const byte* in = info->pk.pqc_sign.in;
646-
word32 in_len = info->pk.pqc_sign.inlen;
647-
byte* out = info->pk.pqc_sign.out;
648-
word32* out_len = info->pk.pqc_sign.outlen;
649-
void* key = info->pk.pqc_sign.key;
650-
int type = info->pk.pqc_sign.type;
645+
const byte* in = info->pk.pqc_sign.in;
646+
word32 in_len = info->pk.pqc_sign.inlen;
647+
byte* out = info->pk.pqc_sign.out;
648+
word32* out_len = info->pk.pqc_sign.outlen;
649+
void* key = info->pk.pqc_sign.key;
650+
int type = info->pk.pqc_sign.type;
651+
const byte* context = info->pk.pqc_sign.context;
652+
byte contextLen = info->pk.pqc_sign.contextLen;
653+
word32 preHashType = info->pk.pqc_sign.preHashType;
651654

652655
#ifndef WOLFHSM_CFG_DMA
653656
if (useDma) {
@@ -661,13 +664,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma)
661664
case WC_PQC_SIG_TYPE_DILITHIUM:
662665
#ifdef WOLFHSM_CFG_DMA
663666
if (useDma) {
664-
ret =
665-
wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len, key);
667+
ret = wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len,
668+
key, context, contextLen,
669+
preHashType);
666670
}
667671
else
668672
#endif /* WOLFHSM_CFG_DMA */
669673
{
670-
ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key);
674+
ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key,
675+
context, contextLen, preHashType);
671676
}
672677
break;
673678
#endif /* HAVE_DILITHIUM */
@@ -688,13 +693,16 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info,
688693
int ret = CRYPTOCB_UNAVAILABLE;
689694

690695
/* Extract info parameters */
691-
const byte* sig = info->pk.pqc_verify.sig;
692-
word32 sig_len = info->pk.pqc_verify.siglen;
693-
const byte* msg = info->pk.pqc_verify.msg;
694-
word32 msg_len = info->pk.pqc_verify.msglen;
695-
int* res = info->pk.pqc_verify.res;
696-
void* key = info->pk.pqc_verify.key;
697-
int type = info->pk.pqc_verify.type;
696+
const byte* sig = info->pk.pqc_verify.sig;
697+
word32 sig_len = info->pk.pqc_verify.siglen;
698+
const byte* msg = info->pk.pqc_verify.msg;
699+
word32 msg_len = info->pk.pqc_verify.msglen;
700+
int* res = info->pk.pqc_verify.res;
701+
void* key = info->pk.pqc_verify.key;
702+
int type = info->pk.pqc_verify.type;
703+
const byte* context = info->pk.pqc_verify.context;
704+
byte contextLen = info->pk.pqc_verify.contextLen;
705+
word32 preHashType = info->pk.pqc_verify.preHashType;
698706

699707
#ifndef WOLFHSM_CFG_DMA
700708
if (useDma) {
@@ -709,13 +717,15 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info,
709717
#ifdef WOLFHSM_CFG_DMA
710718
if (useDma) {
711719
ret = wh_Client_MlDsaVerifyDma(ctx, sig, sig_len, msg, msg_len,
712-
res, key);
720+
res, key, context, contextLen,
721+
preHashType);
713722
}
714723
else
715724
#endif /* WOLFHSM_CFG_DMA */
716725
{
717-
ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len, res,
718-
key);
726+
ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len,
727+
res, key, context, contextLen,
728+
preHashType);
719729
}
720730
break;
721731
#endif /* HAVE_DILITHIUM */

src/wh_message_crypto.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,8 @@ int wh_MessageCrypto_TranslateMlDsaSignRequest(
788788
WH_T32(magic, dest, src, level);
789789
WH_T32(magic, dest, src, keyId);
790790
WH_T32(magic, dest, src, sz);
791+
WH_T32(magic, dest, src, contextSz);
792+
WH_T32(magic, dest, src, preHashType);
791793
return 0;
792794
}
793795

@@ -816,6 +818,8 @@ int wh_MessageCrypto_TranslateMlDsaVerifyRequest(
816818
WH_T32(magic, dest, src, keyId);
817819
WH_T32(magic, dest, src, sigSz);
818820
WH_T32(magic, dest, src, hashSz);
821+
WH_T32(magic, dest, src, contextSz);
822+
WH_T32(magic, dest, src, preHashType);
819823
return 0;
820824
}
821825

@@ -1030,6 +1034,8 @@ int wh_MessageCrypto_TranslateMlDsaSignDmaRequest(
10301034
WH_T32(magic, dest, src, options);
10311035
WH_T32(magic, dest, src, level);
10321036
WH_T32(magic, dest, src, keyId);
1037+
WH_T32(magic, dest, src, contextSz);
1038+
WH_T32(magic, dest, src, preHashType);
10331039

10341040
return 0;
10351041
}
@@ -1079,6 +1085,8 @@ int wh_MessageCrypto_TranslateMlDsaVerifyDmaRequest(
10791085
WH_T32(magic, dest, src, options);
10801086
WH_T32(magic, dest, src, level);
10811087
WH_T32(magic, dest, src, keyId);
1088+
WH_T32(magic, dest, src, contextSz);
1089+
WH_T32(magic, dest, src, preHashType);
10821090

10831091
return 0;
10841092
}

0 commit comments

Comments
 (0)