Skip to content

Commit 3413d88

Browse files
committed
Fix hash copy semantics and add tests
1 parent 17f3332 commit 3413d88

6 files changed

Lines changed: 209 additions & 41 deletions

File tree

scripts/build_ffi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def build_ffi(local_wolfssl, features):
575575
int wc_ShaUpdate(wc_Sha*, const byte*, word32);
576576
int wc_ShaFinal(wc_Sha*, byte*);
577577
void wc_ShaFree(wc_Sha*);
578+
int wc_ShaCopy(wc_Sha*, wc_Sha*);
578579
"""
579580

580581
if features["SHA256"]:
@@ -584,6 +585,7 @@ def build_ffi(local_wolfssl, features):
584585
int wc_Sha256Update(wc_Sha256*, const byte*, word32);
585586
int wc_Sha256Final(wc_Sha256*, byte*);
586587
void wc_Sha256Free(wc_Sha256*);
588+
int wc_Sha256Copy(wc_Sha256*, wc_Sha256*);
587589
"""
588590

589591
if features["SHA384"]:
@@ -593,6 +595,7 @@ def build_ffi(local_wolfssl, features):
593595
int wc_Sha384Update(wc_Sha384*, const byte*, word32);
594596
int wc_Sha384Final(wc_Sha384*, byte*);
595597
void wc_Sha384Free(wc_Sha384*);
598+
int wc_Sha384Copy(wc_Sha384*, wc_Sha384*);
596599
"""
597600

598601
if features["SHA512"]:
@@ -603,6 +606,7 @@ def build_ffi(local_wolfssl, features):
603606
int wc_Sha512Update(wc_Sha512*, const byte*, word32);
604607
int wc_Sha512Final(wc_Sha512*, byte*);
605608
void wc_Sha512Free(wc_Sha512*);
609+
int wc_Sha512Copy(wc_Sha512*, wc_Sha512*);
606610
"""
607611
if features["SHA3"]:
608612
cdef += """
@@ -623,6 +627,10 @@ def build_ffi(local_wolfssl, features):
623627
void wc_Sha3_256_Free(wc_Sha3*);
624628
void wc_Sha3_384_Free(wc_Sha3*);
625629
void wc_Sha3_512_Free(wc_Sha3*);
630+
int wc_Sha3_224_Copy(wc_Sha3*, wc_Sha3*);
631+
int wc_Sha3_256_Copy(wc_Sha3*, wc_Sha3*);
632+
int wc_Sha3_384_Copy(wc_Sha3*, wc_Sha3*);
633+
int wc_Sha3_512_Copy(wc_Sha3*, wc_Sha3*);
626634
"""
627635

628636
if features["DES3"]:

tests/test_aesgcmstream.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,30 @@ def test_encrypt_aad_bad():
126126
def test_invalid_tag_bytes():
127127
key = "fedcba9876543210"
128128
iv = "0123456789abcdef"
129-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
129+
# Out of range
130+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
130131
AesGcmStream(key, iv, tag_bytes=0)
131-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
132+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
132133
AesGcmStream(key, iv, tag_bytes=3)
133-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
134+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
134135
AesGcmStream(key, iv, tag_bytes=17)
135-
# valid edge cases
136-
AesGcmStream(key, iv, tag_bytes=4)
137-
AesGcmStream(key, iv, tag_bytes=16)
136+
# Non-NIST sizes within 4-16 range
137+
for bad in (5, 6, 7, 9, 10, 11):
138+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
139+
AesGcmStream(key, iv, tag_bytes=bad)
140+
# Valid NIST sizes: verify the resulting tag has the requested length.
141+
for good in (4, 8, 12, 13, 14, 15, 16):
142+
gcm = AesGcmStream(key, iv, tag_bytes=good)
143+
gcm.encrypt("hello world")
144+
tag = gcm.final()
145+
assert len(tag) == good
146+
147+
def test_repeated_construction_destruction():
148+
import gc
149+
key = "fedcba9876543210"
150+
iv = "0123456789abcdef"
151+
for _ in range(1000):
152+
gcm = AesGcmStream(key, iv)
153+
gcm.encrypt("hello world")
154+
del gcm
155+
gc.collect()

tests/test_ciphers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,12 @@ def test_des3_rejects_mode_ctr():
881881
with pytest.raises(ValueError, match="Des3 only supports MODE_CBC"):
882882
Des3(key, MODE_CTR, iv)
883883

884+
def test_des3_rejects_mode_ecb():
885+
key = b"\x01\x23\x45\x67\x89\xab\xcd\xef" * 3
886+
iv = b"\xfe\xdc\xba\x98\x76\x54\x32\x10"
887+
with pytest.raises(ValueError, match="Des3 only supports MODE_CBC"):
888+
Des3(key, MODE_ECB, iv)
889+
884890

885891
if _lib.CHACHA_ENABLED:
886892
def test_chacha_non_block_aligned():
@@ -898,3 +904,15 @@ def test_chacha_non_block_aligned():
898904
def test_chacha_invalid_key_length():
899905
with pytest.raises(ValueError, match="key must be"):
900906
ChaCha(b"\x00" * 20)
907+
908+
909+
if _lib.RSA_ENABLED:
910+
def test_encrypt_oaep_requires_hash_type(vectors):
911+
rsa = RsaPublic(vectors[RsaPublic].key)
912+
with pytest.raises(WolfCryptError, match="Hash type not set"):
913+
rsa.encrypt_oaep(b"plaintext")
914+
915+
def test_decrypt_oaep_requires_hash_type(vectors):
916+
rsa = RsaPrivate(vectors[RsaPrivate].key)
917+
with pytest.raises(WolfCryptError, match="Hash type not set"):
918+
rsa.decrypt_oaep(b"\x00" * rsa.output_size)

tests/test_hashes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,12 @@ def test_hash(hash_cls, vectors):
184184
copy.update("wolfcrypt")
185185

186186
assert hash_obj.hexdigest() == copy.hexdigest() == digest
187+
188+
189+
def test_hash_repeated_construction_destruction(hash_cls):
190+
import gc
191+
for _ in range(1000):
192+
h = hash_new(hash_cls, "wolfcrypt")
193+
h.hexdigest()
194+
del h
195+
gc.collect()

wolfcrypt/ciphers.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -396,33 +396,38 @@ class AesGcmStream(object):
396396
block_size = 16
397397
_key_sizes = [16, 24, 32]
398398
_native_type = "Aes *"
399-
_aad = bytes()
400-
_tag_bytes = 16
401-
_mode = None
399+
_delete = _lib.wc_AesFree
402400

403401
def __init__(self, key, IV, tag_bytes=16):
404402
"""
405403
tag_bytes is the number of bytes to use for the authentication tag during encryption
406404
"""
407405
key = t2b(key)
408406
IV = t2b(IV)
409-
if tag_bytes < 4 or tag_bytes > 16:
410-
raise ValueError("tag_bytes must be between 4 and 16")
407+
# NIST SP 800-38D valid GCM tag lengths: 16, 15, 14, 13, 12, 8, 4 bytes.
408+
if tag_bytes not in (4, 8, 12, 13, 14, 15, 16):
409+
raise ValueError(
410+
"tag_bytes must be one of 4, 8, 12, 13, 14, 15, or 16")
411+
# Per-instance state: AAD, tag length, and current mode (enc/dec).
412+
self._aad = bytes()
411413
self._tag_bytes = tag_bytes
414+
self._mode = None
412415
if len(key) not in self._key_sizes:
413416
raise ValueError("key must be %s in length, not %d" %
414417
(self._key_sizes, len(key)))
418+
self._init_done = False
415419
self._native_object = _ffi.new(self._native_type)
416420
ret = _lib.wc_AesInit(self._native_object, _ffi.NULL, -2)
417421
if ret < 0:
418422
raise WolfCryptError("AES init error (%d)" % ret)
423+
self._init_done = True
419424
ret = _lib.wc_AesGcmInit(self._native_object, key, len(key), IV, len(IV))
420425
if ret < 0:
421426
raise WolfCryptError("Init error (%d)" % ret)
422427

423428
def __del__(self):
424-
if hasattr(self, '_native_object'):
425-
_lib.wc_AesFree(self._native_object)
429+
if getattr(self, '_init_done', False):
430+
self._delete(self._native_object)
426431

427432
def set_aad(self, data):
428433
"""
@@ -446,11 +451,11 @@ def encrypt(self, data):
446451
aad = self._aad
447452
elif self._mode == _DECRYPTION:
448453
raise WolfCryptError("Class instance already in use for decryption")
449-
self._buf = _ffi.new("byte[%d]" % (len(data)))
450-
ret = _lib.wc_AesGcmEncryptUpdate(self._native_object, self._buf, data, len(data), aad, len(aad))
454+
buf = _ffi.new("byte[%d]" % (len(data)))
455+
ret = _lib.wc_AesGcmEncryptUpdate(self._native_object, buf, data, len(data), aad, len(aad))
451456
if ret < 0:
452457
raise WolfCryptError("Encryption error (%d)" % ret)
453-
return bytes(self._buf)
458+
return bytes(buf)
454459

455460
def decrypt(self, data):
456461
"""
@@ -463,11 +468,11 @@ def decrypt(self, data):
463468
aad = self._aad
464469
elif self._mode == _ENCRYPTION:
465470
raise WolfCryptError("Class instance already in use for encryption")
466-
self._buf = _ffi.new("byte[%d]" % (len(data)))
467-
ret = _lib.wc_AesGcmDecryptUpdate(self._native_object, self._buf, data, len(data), aad, len(aad))
471+
buf = _ffi.new("byte[%d]" % (len(data)))
472+
ret = _lib.wc_AesGcmDecryptUpdate(self._native_object, buf, data, len(data), aad, len(aad))
468473
if ret < 0:
469474
raise WolfCryptError("Decryption error (%d)" % ret)
470-
return bytes(self._buf)
475+
return bytes(buf)
471476

472477
def final(self, authTag=None):
473478
"""
@@ -505,7 +510,9 @@ class ChaCha(_Cipher):
505510
_IV_nonce = b""
506511
_IV_counter = 0
507512

508-
def __init__(self, key="", size=32):
513+
def __init__(self, key="", size=32): # pylint: disable=unused-argument
514+
# size is kept for backwards compatibility; key length is now
515+
# derived from the actual key and validated against _key_sizes.
509516
self._native_object = _ffi.new(self._native_type)
510517
self._enc = None
511518
self._dec = None
@@ -552,7 +559,9 @@ def set_iv(self, nonce, counter = 0):
552559
raise ValueError("nonce must be %d bytes, got %d" %
553560
(self._NONCE_SIZE, len(self._IV_nonce)))
554561
self._IV_counter = counter
555-
self._set_key(0)
562+
ret = self._set_key(0)
563+
if ret < 0:
564+
raise WolfCryptError("ChaCha set_iv error (%d)" % ret)
556565

557566
if _lib.CHACHA20_POLY1305_ENABLED:
558567
class ChaCha20Poly1305(object):
@@ -1231,7 +1240,11 @@ def make_key(cls, size, rng=None):
12311240
ret = _lib.wc_ecc_set_rng(ecc.native_object, rng.native_object)
12321241
if ret < 0:
12331242
raise WolfCryptError("Error setting ECC RNG (%d)" % ret)
1234-
ecc._rng = rng
1243+
1244+
# Retain the RNG so it outlives the ECC key. Even outside the
1245+
# timing-resistance path, wolfSSL internals may retain a pointer
1246+
# to the RNG; keeping the reference avoids any UAF risk.
1247+
ecc._rng = rng
12351248

12361249
return ecc
12371250

@@ -1504,6 +1517,10 @@ def make_key(cls, size, rng=None):
15041517
if ret < 0:
15051518
raise WolfCryptError("Key generation error (%d)" % ret)
15061519

1520+
# Retain RNG reference defensively; wolfSSL may retain a pointer
1521+
# internally on some builds.
1522+
ed25519._rng = rng
1523+
15071524
return ed25519
15081525

15091526
def decode_key(self, key, pub = None):
@@ -1706,6 +1723,10 @@ def make_key(cls, size, rng=None):
17061723
if ret < 0:
17071724
raise WolfCryptError("Key generation error (%d)" % ret)
17081725

1726+
# Retain RNG reference defensively; wolfSSL may retain a pointer
1727+
# internally on some builds.
1728+
ed448._rng = rng
1729+
17091730
return ed448
17101731

17111732
def decode_key(self, key, pub = None):
@@ -1979,6 +2000,9 @@ def make_key(cls, mlkem_type, rng=None):
19792000
if ret < 0: # pragma: no cover
19802001
raise WolfCryptError("wc_KyberKey_MakeKey() error (%d)" % ret)
19812002

2003+
# Retain RNG reference defensively.
2004+
mlkem_priv._rng = rng
2005+
19822006
return mlkem_priv
19832007

19842008
@classmethod
@@ -2226,6 +2250,9 @@ def make_key(cls, mldsa_type, rng=None):
22262250
if ret < 0: # pragma: no cover
22272251
raise WolfCryptError("wc_dilithium_make_key() error (%d)" % ret)
22282252

2253+
# Retain RNG reference defensively.
2254+
mldsa_priv._rng = rng
2255+
22292256
return mldsa_priv
22302257

22312258
@property

0 commit comments

Comments
 (0)