@@ -5474,7 +5474,7 @@ int wh_Client_MlDsaImportKey(whClientContext* ctx, MlDsaKey* key,
54745474{
54755475 int ret = WH_ERROR_OK ;
54765476 whKeyId key_id = WH_KEYID_ERASED ;
5477- byte buffer [DILITHIUM_MAX_PRV_KEY_SIZE ];
5477+ byte buffer [DILITHIUM_MAX_BOTH_KEY_DER_SIZE ];
54785478 uint16_t buffer_len = 0 ;
54795479
54805480 if ((ctx == NULL ) || (key == NULL ) ||
@@ -5511,7 +5511,7 @@ int wh_Client_MlDsaExportKey(whClientContext* ctx, whKeyId keyId, MlDsaKey* key,
55115511{
55125512 int ret = WH_ERROR_OK ;
55135513 /* buffer cannot be larger than MTU */
5514- byte buffer [DILITHIUM_MAX_PRV_KEY_SIZE ];
5514+ byte buffer [DILITHIUM_MAX_BOTH_KEY_DER_SIZE ];
55155515 uint16_t buffer_len = sizeof (buffer );
55165516
55175517 if ((ctx == NULL ) || WH_KEYID_ISERASED (keyId ) || (key == NULL )) {
@@ -5666,7 +5666,9 @@ int wh_Client_MlDsaMakeExportKey(whClientContext* ctx, int level, int size,
56665666
56675667
56685668int wh_Client_MlDsaSign (whClientContext * ctx , const byte * in , word32 in_len ,
5669- byte * out , word32 * inout_len , MlDsaKey * key )
5669+ byte * out , word32 * inout_len , MlDsaKey * key ,
5670+ const byte * context , byte contextLen ,
5671+ word32 preHashType )
56705672{
56715673 int ret = 0 ;
56725674 whMessageCrypto_MlDsaSignRequest * req = NULL ;
@@ -5709,7 +5711,7 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57095711 uint16_t action = WC_ALGO_TYPE_PK ;
57105712
57115713 uint16_t req_len = sizeof (whMessageCrypto_GenericRequestHeader ) +
5712- sizeof (* req ) + in_len ;
5714+ sizeof (* req ) + in_len + contextLen ;
57135715 uint32_t options = 0 ;
57145716
57155717 /* Get data pointer from the context to use as request/response storage
@@ -5726,18 +5728,23 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57265728 ctx -> cryptoAffinity );
57275729
57285730 if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN ) {
5729- uint8_t * req_hash = (uint8_t * )(req + 1 );
5731+ uint8_t * req_data = (uint8_t * )(req + 1 );
57305732 if (evict != 0 ) {
57315733 options |= WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT ;
57325734 }
57335735
57345736 memset (req , 0 , sizeof (* req ));
5735- req -> options = options ;
5736- req -> level = key -> level ;
5737- req -> keyId = key_id ;
5738- req -> sz = in_len ;
5737+ req -> options = options ;
5738+ req -> level = key -> level ;
5739+ req -> keyId = key_id ;
5740+ req -> sz = in_len ;
5741+ req -> contextSz = contextLen ;
5742+ req -> preHashType = preHashType ;
57395743 if ((in != NULL ) && (in_len > 0 )) {
5740- memcpy (req_hash , in , in_len );
5744+ memcpy (req_data , in , in_len );
5745+ }
5746+ if ((context != NULL ) && (contextLen > 0 )) {
5747+ memcpy (req_data + in_len , context , contextLen );
57415748 }
57425749
57435750 /* Send Request */
@@ -5794,8 +5801,9 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
57945801}
57955802
57965803int wh_Client_MlDsaVerify (whClientContext * ctx , const byte * sig , word32 sig_len ,
5797- const byte * msg , word32 msg_len , int * out_res ,
5798- MlDsaKey * key )
5804+ const byte * msg , word32 msg_len , int * out_res ,
5805+ MlDsaKey * key , const byte * context , byte contextLen ,
5806+ word32 preHashType )
57995807{
58005808 int ret = WH_ERROR_OK ;
58015809 uint8_t * dataPtr = NULL ;
@@ -5838,7 +5846,7 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
58385846 uint32_t options = 0 ;
58395847
58405848 uint16_t req_len = sizeof (whMessageCrypto_GenericRequestHeader ) +
5841- sizeof (* req ) + sig_len + msg_len ;
5849+ sizeof (* req ) + sig_len + msg_len + contextLen ;
58425850
58435851
58445852 /* Get data pointer from the context to use as request/response storage
@@ -5864,17 +5872,22 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
58645872 }
58655873
58665874 memset (req , 0 , sizeof (* req ));
5867- req -> options = options ;
5868- req -> level = key -> level ;
5869- req -> keyId = key_id ;
5870- req -> sigSz = sig_len ;
5875+ req -> options = options ;
5876+ req -> level = key -> level ;
5877+ req -> keyId = key_id ;
5878+ req -> sigSz = sig_len ;
58715879 if ((sig != NULL ) && (sig_len > 0 )) {
58725880 memcpy (req_sig , sig , sig_len );
58735881 }
5874- req -> hashSz = msg_len ;
5882+ req -> hashSz = msg_len ;
58755883 if ((msg != NULL ) && (msg_len > 0 )) {
58765884 memcpy (req_hash , msg , msg_len );
58775885 }
5886+ req -> contextSz = contextLen ;
5887+ req -> preHashType = preHashType ;
5888+ if ((context != NULL ) && (contextLen > 0 )) {
5889+ memcpy (req_hash + msg_len , context , contextLen );
5890+ }
58785891
58795892 /* write request */
58805893 ret = wh_Client_SendRequest (ctx , group , action , req_len ,
@@ -5937,7 +5950,7 @@ int wh_Client_MlDsaImportKeyDma(whClientContext* ctx, MlDsaKey* key,
59375950{
59385951 int ret = WH_ERROR_OK ;
59395952 whKeyId key_id = WH_KEYID_ERASED ;
5940- byte buffer [DILITHIUM_MAX_PRV_KEY_SIZE ];
5953+ byte buffer [DILITHIUM_MAX_BOTH_KEY_DER_SIZE ];
59415954 uint16_t buffer_len = 0 ;
59425955
59435956 if ((ctx == NULL ) || (key == NULL ) ||
@@ -5969,7 +5982,7 @@ int wh_Client_MlDsaExportKeyDma(whClientContext* ctx, whKeyId keyId,
59695982 uint8_t * label )
59705983{
59715984 int ret = WH_ERROR_OK ;
5972- byte buffer [DILITHIUM_MAX_PRV_KEY_SIZE ] = {0 };
5985+ byte buffer [DILITHIUM_MAX_BOTH_KEY_DER_SIZE ] = {0 };
59735986 uint16_t buffer_len = sizeof (buffer );
59745987
59755988 if ((ctx == NULL ) || WH_KEYID_ISERASED (keyId ) || (key == NULL )) {
@@ -5993,7 +6006,7 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level,
59936006{
59946007 int ret = WH_ERROR_OK ;
59956008 whKeyId key_id = WH_KEYID_ERASED ;
5996- byte buffer [DILITHIUM_MAX_PRV_KEY_SIZE ];
6009+ byte buffer [DILITHIUM_MAX_BOTH_KEY_DER_SIZE ];
59976010 uint8_t * dataPtr = NULL ;
59986011 whMessageCrypto_MlDsaKeyGenDmaRequest * req = NULL ;
59996012 whMessageCrypto_MlDsaKeyGenDmaResponse * res = NULL ;
@@ -6116,7 +6129,9 @@ int wh_Client_MlDsaMakeExportKeyDma(whClientContext* ctx, int level,
61166129
61176130
61186131int wh_Client_MlDsaSignDma (whClientContext * ctx , const byte * in , word32 in_len ,
6119- byte * out , word32 * out_len , MlDsaKey * key )
6132+ byte * out , word32 * out_len , MlDsaKey * key ,
6133+ const byte * context , byte contextLen ,
6134+ word32 preHashType )
61206135{
61216136 int ret = 0 ;
61226137 whMessageCrypto_MlDsaSignDmaRequest * req = NULL ;
@@ -6158,7 +6173,8 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
61586173 uint16_t action = WC_ALGO_TYPE_PK ;
61596174
61606175 uint16_t req_len =
6161- sizeof (whMessageCrypto_GenericRequestHeader ) + sizeof (* req );
6176+ sizeof (whMessageCrypto_GenericRequestHeader ) + sizeof (* req ) +
6177+ contextLen ;
61626178 uint32_t options = 0 ;
61636179
61646180 /* Get data pointer from the context to use as request/response storage
@@ -6180,9 +6196,14 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
61806196 }
61816197
61826198 memset (req , 0 , sizeof (* req ));
6183- req -> options = options ;
6184- req -> level = key -> level ;
6185- req -> keyId = key_id ;
6199+ req -> options = options ;
6200+ req -> level = key -> level ;
6201+ req -> keyId = key_id ;
6202+ req -> contextSz = contextLen ;
6203+ req -> preHashType = preHashType ;
6204+ if ((context != NULL ) && (contextLen > 0 )) {
6205+ memcpy ((uint8_t * )(req + 1 ), context , contextLen );
6206+ }
61866207
61876208 /* Set up DMA buffers */
61886209 req -> msg .sz = in_len ;
@@ -6255,8 +6276,9 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
62556276}
62566277
62576278int wh_Client_MlDsaVerifyDma (whClientContext * ctx , const byte * sig ,
6258- word32 sig_len , const byte * msg , word32 msg_len ,
6259- int * out_res , MlDsaKey * key )
6279+ word32 sig_len , const byte * msg , word32 msg_len ,
6280+ int * out_res , MlDsaKey * key , const byte * context ,
6281+ byte contextLen , word32 preHashType )
62606282{
62616283 int ret = 0 ;
62626284 whMessageCrypto_MlDsaVerifyDmaRequest * req = NULL ;
@@ -6296,7 +6318,8 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
62966318 uintptr_t msgAddr = 0 ;
62976319
62986320 uint16_t req_len =
6299- sizeof (whMessageCrypto_GenericRequestHeader ) + sizeof (* req );
6321+ sizeof (whMessageCrypto_GenericRequestHeader ) + sizeof (* req ) +
6322+ contextLen ;
63006323
63016324 /* Get data pointer from the context to use as request/response storage
63026325 */
@@ -6317,9 +6340,14 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
63176340 }
63186341
63196342 memset (req , 0 , sizeof (* req ));
6320- req -> options = options ;
6321- req -> level = key -> level ;
6322- req -> keyId = key_id ;
6343+ req -> options = options ;
6344+ req -> level = key -> level ;
6345+ req -> keyId = key_id ;
6346+ req -> contextSz = contextLen ;
6347+ req -> preHashType = preHashType ;
6348+ if ((context != NULL ) && (contextLen > 0 )) {
6349+ memcpy ((uint8_t * )(req + 1 ), context , contextLen );
6350+ }
63236351
63246352 /* Set up DMA buffers */
63256353 req -> sig .sz = sig_len ;
0 commit comments