7575#include <wolfssl/wolfcrypt/wc_mlkem.h>
7676#include <wolfssl/wolfcrypt/hash.h>
7777#include <wolfssl/wolfcrypt/memory.h>
78+ #ifdef WOLF_CRYPTO_CB
79+ #include <wolfssl/wolfcrypt/cryptocb.h>
80+ #endif
7881
7982#ifdef NO_INLINE
8083 #include <wolfssl/wolfcrypt/misc.h>
@@ -298,9 +301,13 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
298301 /* Cache heap pointer. */
299302 key -> heap = heap ;
300303 #ifdef WOLF_CRYPTO_CB
301- /* Cache device id - not used in this algorithm yet. */
304+ key -> devCtx = NULL ;
302305 key -> devId = devId ;
303306 #endif
307+ #ifdef WOLF_PRIVATE_KEY_ID
308+ key -> idLen = 0 ;
309+ key -> labelLen = 0 ;
310+ #endif
304311 key -> flags = 0 ;
305312
306313 /* Zero out all data. */
@@ -322,6 +329,58 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
322329 return ret ;
323330}
324331
332+ #ifdef WOLF_PRIVATE_KEY_ID
333+ int wc_MlKemKey_Init_Id (MlKemKey * key , int type , const unsigned char * id ,
334+ int len , void * heap , int devId )
335+ {
336+ int ret = 0 ;
337+
338+ if (key == NULL || (id == NULL && len != 0 )) {
339+ ret = BAD_FUNC_ARG ;
340+ }
341+ if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN )) {
342+ ret = BUFFER_E ;
343+ }
344+
345+ if (ret == 0 ) {
346+ ret = wc_MlKemKey_Init (key , type , heap , devId );
347+ }
348+ if (ret == 0 && id != NULL && len != 0 ) {
349+ XMEMCPY (key -> id , id , (size_t )len );
350+ key -> idLen = len ;
351+ }
352+
353+ return ret ;
354+ }
355+
356+ int wc_MlKemKey_Init_Label (MlKemKey * key , int type , const char * label ,
357+ void * heap , int devId )
358+ {
359+ int ret = 0 ;
360+ int labelLen = 0 ;
361+
362+ if (key == NULL || label == NULL ) {
363+ ret = BAD_FUNC_ARG ;
364+ }
365+ if (ret == 0 ) {
366+ labelLen = (int )XSTRLEN (label );
367+ if ((labelLen == 0 ) || (labelLen > MLKEM_MAX_LABEL_LEN )) {
368+ ret = BUFFER_E ;
369+ }
370+ }
371+
372+ if (ret == 0 ) {
373+ ret = wc_MlKemKey_Init (key , type , heap , devId );
374+ }
375+ if (ret == 0 ) {
376+ XMEMCPY (key -> label , label , (size_t )labelLen );
377+ key -> labelLen = labelLen ;
378+ }
379+
380+ return ret ;
381+ }
382+ #endif
383+
325384/**
326385 * Free the Kyber key object.
327386 *
@@ -330,7 +389,22 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
330389 */
331390int wc_MlKemKey_Free (MlKemKey * key )
332391{
392+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
393+ int ret = 0 ;
394+ #endif
395+
333396 if (key != NULL ) {
397+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
398+ if (key -> devId != INVALID_DEVID ) {
399+ ret = wc_CryptoCb_Free (key -> devId , WC_ALGO_TYPE_PK ,
400+ WC_PK_TYPE_PQC_KEM_KEYGEN , WC_PQC_KEM_TYPE_KYBER , (void * )key );
401+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE )) {
402+ return ret ;
403+ }
404+ /* fall-through to software cleanup */
405+ }
406+ (void )ret ;
407+ #endif
334408 /* Dispose of PRF object. */
335409 mlkem_prf_free (& key -> prf );
336410 /* Dispose of hash object. */
@@ -382,6 +456,21 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
382456 ret = BAD_FUNC_ARG ;
383457 }
384458
459+ #ifdef WOLF_CRYPTO_CB
460+ if ((ret == 0 )
461+ #ifndef WOLF_CRYPTO_CB_FIND
462+ && (key -> devId != INVALID_DEVID )
463+ #endif
464+ ) {
465+ ret = wc_CryptoCb_MakePqcKemKey (rng , WC_PQC_KEM_TYPE_KYBER ,
466+ key -> type , key );
467+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
468+ return ret ;
469+ /* fall-through when unavailable */
470+ ret = 0 ;
471+ }
472+ #endif
473+
385474 if (ret == 0 ) {
386475 /* Generate random to use with PRFs.
387476 * Step 1: d is 32 random bytes
@@ -1063,12 +1152,33 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10631152#ifndef WC_NO_RNG
10641153 int ret = 0 ;
10651154 unsigned char m [WC_ML_KEM_ENC_RAND_SZ ];
1155+ #ifdef WOLF_CRYPTO_CB
1156+ word32 ctlen = 0 ;
1157+ #endif
10661158
10671159 /* Validate parameters. */
10681160 if ((key == NULL ) || (c == NULL ) || (k == NULL ) || (rng == NULL )) {
10691161 ret = BAD_FUNC_ARG ;
10701162 }
10711163
1164+ #ifdef WOLF_CRYPTO_CB
1165+ if (ret == 0 ) {
1166+ ret = wc_MlKemKey_CipherTextSize (key , & ctlen );
1167+ }
1168+ if ((ret == 0 )
1169+ #ifndef WOLF_CRYPTO_CB_FIND
1170+ && (key -> devId != INVALID_DEVID )
1171+ #endif
1172+ ) {
1173+ ret = wc_CryptoCb_PqcEncapsulate (c , ctlen , k , KYBER_SS_SZ , rng ,
1174+ WC_PQC_KEM_TYPE_KYBER , key );
1175+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1176+ return ret ;
1177+ /* fall-through when unavailable */
1178+ ret = 0 ;
1179+ }
1180+ #endif
1181+
10721182 if (ret == 0 ) {
10731183 /* Generate seed for use with PRFs.
10741184 * Step 1: m is 32 random bytes
@@ -1534,6 +1644,21 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
15341644 ret = BUFFER_E ;
15351645 }
15361646
1647+ #ifdef WOLF_CRYPTO_CB
1648+ if ((ret == 0 )
1649+ #ifndef WOLF_CRYPTO_CB_FIND
1650+ && (key -> devId != INVALID_DEVID )
1651+ #endif
1652+ ) {
1653+ ret = wc_CryptoCb_PqcDecapsulate (ct , ctSz , ss , KYBER_SS_SZ ,
1654+ WC_PQC_KEM_TYPE_KYBER , key );
1655+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1656+ return ret ;
1657+ /* fall-through when unavailable */
1658+ ret = 0 ;
1659+ }
1660+ #endif
1661+
15371662#if !defined(USE_INTEL_SPEEDUP ) && !defined(WOLFSSL_NO_MALLOC )
15381663 if (ret == 0 ) {
15391664 /* Allocate memory for cipher text that is generated. */
0 commit comments