7575#include <wolfssl/wolfcrypt/wc_mlkem.h>
7676#include <wolfssl/wolfcrypt/hash.h>
7777#include <wolfssl/wolfcrypt/memory.h>
78+ #ifdef WOLF_CRYPTO_CB
79+ #include <wolfssl/wolfcrypt/cryptocb.h>
80+ #endif
7881
7982#ifdef NO_INLINE
8083 #include <wolfssl/wolfcrypt/misc.h>
@@ -298,9 +301,13 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
298301 /* Cache heap pointer. */
299302 key -> heap = heap ;
300303 #ifdef WOLF_CRYPTO_CB
301- /* Cache device id - not used in this algorithm yet. */
304+ key -> devCtx = NULL ;
302305 key -> devId = devId ;
303306 #endif
307+ #ifdef WOLF_PRIVATE_KEY_ID
308+ key -> idLen = 0 ;
309+ key -> labelLen = 0 ;
310+ #endif
304311 key -> flags = 0 ;
305312
306313 /* Zero out all data. */
@@ -322,6 +329,60 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
322329 return ret ;
323330}
324331
332+ #ifdef WOLF_PRIVATE_KEY_ID
333+ int wc_MlKemKey_Init_Id (MlKemKey * key , const unsigned char * id , int len ,
334+ void * heap , int devId )
335+ {
336+ int ret = 0 ;
337+
338+ if (key == NULL || (id == NULL && len != 0 )) {
339+ ret = BAD_FUNC_ARG ;
340+ }
341+ if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN )) {
342+ ret = BUFFER_E ;
343+ }
344+
345+ if (ret == 0 ) {
346+ /* Use max level so PKCS#11 lookup has a key object to operate on. */
347+ ret = wc_MlKemKey_Init (key , WC_ML_KEM_1024 , heap , devId );
348+ }
349+ if (ret == 0 && id != NULL && len != 0 ) {
350+ XMEMCPY (key -> id , id , (size_t )len );
351+ key -> idLen = len ;
352+ }
353+
354+ return ret ;
355+ }
356+
357+ int wc_MlKemKey_Init_Label (MlKemKey * key , const char * label , void * heap ,
358+ int devId )
359+ {
360+ int ret = 0 ;
361+ int labelLen = 0 ;
362+
363+ if (key == NULL || label == NULL ) {
364+ ret = BAD_FUNC_ARG ;
365+ }
366+ if (ret == 0 ) {
367+ labelLen = (int )XSTRLEN (label );
368+ if ((labelLen == 0 ) || (labelLen > MLKEM_MAX_LABEL_LEN )) {
369+ ret = BUFFER_E ;
370+ }
371+ }
372+
373+ if (ret == 0 ) {
374+ /* Use max level so PKCS#11 lookup has a key object to operate on. */
375+ ret = wc_MlKemKey_Init (key , WC_ML_KEM_1024 , heap , devId );
376+ }
377+ if (ret == 0 ) {
378+ XMEMCPY (key -> label , label , (size_t )labelLen );
379+ key -> labelLen = labelLen ;
380+ }
381+
382+ return ret ;
383+ }
384+ #endif
385+
325386/**
326387 * Free the Kyber key object.
327388 *
@@ -330,7 +391,22 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
330391 */
331392int wc_MlKemKey_Free (MlKemKey * key )
332393{
394+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
395+ int ret = 0 ;
396+ #endif
397+
333398 if (key != NULL ) {
399+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
400+ if (key -> devId != INVALID_DEVID ) {
401+ ret = wc_CryptoCb_Free (key -> devId , WC_ALGO_TYPE_PK ,
402+ WC_PK_TYPE_PQC_KEM_KEYGEN , WC_PQC_KEM_TYPE_KYBER , (void * )key );
403+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE )) {
404+ return ret ;
405+ }
406+ /* fall-through to software cleanup */
407+ }
408+ (void )ret ;
409+ #endif
334410 /* Dispose of PRF object. */
335411 mlkem_prf_free (& key -> prf );
336412 /* Dispose of hash object. */
@@ -382,6 +458,21 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
382458 ret = BAD_FUNC_ARG ;
383459 }
384460
461+ #ifdef WOLF_CRYPTO_CB
462+ if ((ret == 0 )
463+ #ifndef WOLF_CRYPTO_CB_FIND
464+ && (key -> devId != INVALID_DEVID )
465+ #endif
466+ ) {
467+ ret = wc_CryptoCb_MakePqcKemKey (rng , WC_PQC_KEM_TYPE_KYBER ,
468+ key -> type , key );
469+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
470+ return ret ;
471+ /* fall-through when unavailable */
472+ ret = 0 ;
473+ }
474+ #endif
475+
385476 if (ret == 0 ) {
386477 /* Generate random to use with PRFs.
387478 * Step 1: d is 32 random bytes
@@ -1063,12 +1154,33 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10631154#ifndef WC_NO_RNG
10641155 int ret = 0 ;
10651156 unsigned char m [WC_ML_KEM_ENC_RAND_SZ ];
1157+ #ifdef WOLF_CRYPTO_CB
1158+ word32 ctlen = 0 ;
1159+ #endif
10661160
10671161 /* Validate parameters. */
10681162 if ((key == NULL ) || (c == NULL ) || (k == NULL ) || (rng == NULL )) {
10691163 ret = BAD_FUNC_ARG ;
10701164 }
10711165
1166+ #ifdef WOLF_CRYPTO_CB
1167+ if (ret == 0 ) {
1168+ ret = wc_MlKemKey_CipherTextSize (key , & ctlen );
1169+ }
1170+ if ((ret == 0 )
1171+ #ifndef WOLF_CRYPTO_CB_FIND
1172+ && (key -> devId != INVALID_DEVID )
1173+ #endif
1174+ ) {
1175+ ret = wc_CryptoCb_PqcEncapsulate (c , ctlen , k , KYBER_SS_SZ , rng ,
1176+ WC_PQC_KEM_TYPE_KYBER , key );
1177+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1178+ return ret ;
1179+ /* fall-through when unavailable */
1180+ ret = 0 ;
1181+ }
1182+ #endif
1183+
10721184 if (ret == 0 ) {
10731185 /* Generate seed for use with PRFs.
10741186 * Step 1: m is 32 random bytes
@@ -1531,6 +1643,21 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
15311643 ret = BUFFER_E ;
15321644 }
15331645
1646+ #ifdef WOLF_CRYPTO_CB
1647+ if ((ret == 0 )
1648+ #ifndef WOLF_CRYPTO_CB_FIND
1649+ && (key -> devId != INVALID_DEVID )
1650+ #endif
1651+ ) {
1652+ ret = wc_CryptoCb_PqcDecapsulate (ct , ctSz , ss , KYBER_SS_SZ ,
1653+ WC_PQC_KEM_TYPE_KYBER , key );
1654+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1655+ return ret ;
1656+ /* fall-through when unavailable */
1657+ ret = 0 ;
1658+ }
1659+ #endif
1660+
15341661#if !defined(USE_INTEL_SPEEDUP ) && !defined(WOLFSSL_NO_MALLOC )
15351662 if (ret == 0 ) {
15361663 /* Allocate memory for cipher text that is generated. */
0 commit comments