Skip to content

Commit 55382b3

Browse files
committed
Fix hash copy semantics and add tests
1 parent 17f3332 commit 55382b3

7 files changed

Lines changed: 220 additions & 44 deletions

File tree

scripts/build_ffi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def build_ffi(local_wolfssl, features):
575575
int wc_ShaUpdate(wc_Sha*, const byte*, word32);
576576
int wc_ShaFinal(wc_Sha*, byte*);
577577
void wc_ShaFree(wc_Sha*);
578+
int wc_ShaCopy(wc_Sha*, wc_Sha*);
578579
"""
579580

580581
if features["SHA256"]:
@@ -584,6 +585,7 @@ def build_ffi(local_wolfssl, features):
584585
int wc_Sha256Update(wc_Sha256*, const byte*, word32);
585586
int wc_Sha256Final(wc_Sha256*, byte*);
586587
void wc_Sha256Free(wc_Sha256*);
588+
int wc_Sha256Copy(wc_Sha256*, wc_Sha256*);
587589
"""
588590

589591
if features["SHA384"]:
@@ -593,6 +595,7 @@ def build_ffi(local_wolfssl, features):
593595
int wc_Sha384Update(wc_Sha384*, const byte*, word32);
594596
int wc_Sha384Final(wc_Sha384*, byte*);
595597
void wc_Sha384Free(wc_Sha384*);
598+
int wc_Sha384Copy(wc_Sha384*, wc_Sha384*);
596599
"""
597600

598601
if features["SHA512"]:
@@ -603,6 +606,7 @@ def build_ffi(local_wolfssl, features):
603606
int wc_Sha512Update(wc_Sha512*, const byte*, word32);
604607
int wc_Sha512Final(wc_Sha512*, byte*);
605608
void wc_Sha512Free(wc_Sha512*);
609+
int wc_Sha512Copy(wc_Sha512*, wc_Sha512*);
606610
"""
607611
if features["SHA3"]:
608612
cdef += """
@@ -623,6 +627,10 @@ def build_ffi(local_wolfssl, features):
623627
void wc_Sha3_256_Free(wc_Sha3*);
624628
void wc_Sha3_384_Free(wc_Sha3*);
625629
void wc_Sha3_512_Free(wc_Sha3*);
630+
int wc_Sha3_224_Copy(wc_Sha3*, wc_Sha3*);
631+
int wc_Sha3_256_Copy(wc_Sha3*, wc_Sha3*);
632+
int wc_Sha3_384_Copy(wc_Sha3*, wc_Sha3*);
633+
int wc_Sha3_512_Copy(wc_Sha3*, wc_Sha3*);
626634
"""
627635

628636
if features["DES3"]:

tests/test_aesgcmstream.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,31 @@ def test_encrypt_aad_bad():
126126
def test_invalid_tag_bytes():
127127
key = "fedcba9876543210"
128128
iv = "0123456789abcdef"
129-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
129+
# Out of range
130+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
130131
AesGcmStream(key, iv, tag_bytes=0)
131-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
132+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
132133
AesGcmStream(key, iv, tag_bytes=3)
133-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
134+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
134135
AesGcmStream(key, iv, tag_bytes=17)
135-
# valid edge cases
136-
AesGcmStream(key, iv, tag_bytes=4)
137-
AesGcmStream(key, iv, tag_bytes=16)
136+
# Non-NIST sizes within 4-16 range
137+
for bad in (5, 6, 7, 9, 10, 11):
138+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
139+
AesGcmStream(key, iv, tag_bytes=bad)
140+
# Valid NIST sizes: verify the resulting tag has the requested length.
141+
for good in (4, 8, 12, 13, 14, 15, 16):
142+
gcm = AesGcmStream(key, iv, tag_bytes=good)
143+
gcm.encrypt("hello world")
144+
tag = gcm.final()
145+
assert len(tag) == good
146+
147+
def test_repeated_construction_destruction():
148+
import gc
149+
key = "fedcba9876543210"
150+
iv = "0123456789abcdef"
151+
for _ in range(1000):
152+
gcm = AesGcmStream(key, iv)
153+
gcm.encrypt("hello world")
154+
gcm.final()
155+
del gcm
156+
gc.collect()

tests/test_ciphers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,13 @@ def test_des3_rejects_mode_ctr():
879879
key = b"\x01\x23\x45\x67\x89\xab\xcd\xef" * 3
880880
iv = b"\xfe\xdc\xba\x98\x76\x54\x32\x10"
881881
with pytest.raises(ValueError, match="Des3 only supports MODE_CBC"):
882-
Des3(key, MODE_CTR, iv)
882+
Des3.new(key, MODE_CTR, iv)
883+
884+
def test_des3_rejects_mode_ecb():
885+
key = b"\x01\x23\x45\x67\x89\xab\xcd\xef" * 3
886+
iv = b"\xfe\xdc\xba\x98\x76\x54\x32\x10"
887+
with pytest.raises(ValueError, match="Des3 only supports MODE_CBC"):
888+
Des3.new(key, MODE_ECB, iv)
883889

884890

885891
if _lib.CHACHA_ENABLED:
@@ -898,3 +904,15 @@ def test_chacha_non_block_aligned():
898904
def test_chacha_invalid_key_length():
899905
with pytest.raises(ValueError, match="key must be"):
900906
ChaCha(b"\x00" * 20)
907+
908+
909+
if _lib.RSA_ENABLED:
910+
def test_encrypt_oaep_requires_hash_type(vectors):
911+
rsa = RsaPublic(vectors[RsaPublic].key)
912+
with pytest.raises(WolfCryptError, match="Hash type not set"):
913+
rsa.encrypt_oaep(b"plaintext")
914+
915+
def test_decrypt_oaep_requires_hash_type(vectors):
916+
rsa = RsaPrivate(vectors[RsaPrivate].key)
917+
with pytest.raises(WolfCryptError, match="Hash type not set"):
918+
rsa.decrypt_oaep(b"\x00" * rsa.output_size)

tests/test_hashes.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,13 @@ def test_hash(hash_cls, vectors):
184184
copy.update("wolfcrypt")
185185

186186
assert hash_obj.hexdigest() == copy.hexdigest() == digest
187+
188+
189+
def test_hash_repeated_construction_destruction(hash_cls, vectors):
190+
import gc
191+
digest = vectors[hash_cls].digest
192+
for _ in range(1000):
193+
h = hash_new(hash_cls, "wolfcrypt")
194+
assert h.hexdigest() == digest
195+
del h
196+
gc.collect()

wolfcrypt/asn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ def pem_to_der(pem, pem_type):
4444
err = "Error converting from PEM to DER. ({})".format(ret)
4545
raise WolfCryptError(err)
4646

47-
result = _ffi.buffer(der[0][0].buffer, der[0][0].length)[:]
48-
_lib.wc_FreeDer(der)
47+
try:
48+
result = _ffi.buffer(der[0][0].buffer, der[0][0].length)[:]
49+
finally:
50+
_lib.wc_FreeDer(der)
4951
return result
5052

5153
def der_to_pem(der, pem_type):

wolfcrypt/ciphers.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -396,33 +396,39 @@ class AesGcmStream(object):
396396
block_size = 16
397397
_key_sizes = [16, 24, 32]
398398
_native_type = "Aes *"
399-
_aad = bytes()
400-
_tag_bytes = 16
401-
_mode = None
399+
_delete = _lib.wc_AesFree
402400

403401
def __init__(self, key, IV, tag_bytes=16):
404402
"""
405403
tag_bytes is the number of bytes to use for the authentication tag during encryption
406404
"""
407405
key = t2b(key)
408406
IV = t2b(IV)
409-
if tag_bytes < 4 or tag_bytes > 16:
410-
raise ValueError("tag_bytes must be between 4 and 16")
407+
# NIST SP 800-38D valid GCM tag lengths: 16, 15, 14, 13, 12, 8, 4 bytes.
408+
if tag_bytes not in (4, 8, 12, 13, 14, 15, 16):
409+
raise ValueError(
410+
"tag_bytes must be one of 4, 8, 12, 13, 14, 15, or 16")
411+
# Per-instance state: AAD, tag length, and current mode (enc/dec).
412+
self._aad = bytes()
411413
self._tag_bytes = tag_bytes
414+
self._mode = None
412415
if len(key) not in self._key_sizes:
413416
raise ValueError("key must be %s in length, not %d" %
414417
(self._key_sizes, len(key)))
418+
self._init_done = False
415419
self._native_object = _ffi.new(self._native_type)
416420
ret = _lib.wc_AesInit(self._native_object, _ffi.NULL, -2)
417421
if ret < 0:
418422
raise WolfCryptError("AES init error (%d)" % ret)
423+
self._init_done = True
419424
ret = _lib.wc_AesGcmInit(self._native_object, key, len(key), IV, len(IV))
420425
if ret < 0:
421426
raise WolfCryptError("Init error (%d)" % ret)
422427

423428
def __del__(self):
424-
if hasattr(self, '_native_object'):
425-
_lib.wc_AesFree(self._native_object)
429+
if getattr(self, '_init_done', False):
430+
self._delete(self._native_object)
431+
self._init_done = False
426432

427433
def set_aad(self, data):
428434
"""
@@ -446,11 +452,11 @@ def encrypt(self, data):
446452
aad = self._aad
447453
elif self._mode == _DECRYPTION:
448454
raise WolfCryptError("Class instance already in use for decryption")
449-
self._buf = _ffi.new("byte[%d]" % (len(data)))
450-
ret = _lib.wc_AesGcmEncryptUpdate(self._native_object, self._buf, data, len(data), aad, len(aad))
455+
buf = _ffi.new("byte[%d]" % (len(data)))
456+
ret = _lib.wc_AesGcmEncryptUpdate(self._native_object, buf, data, len(data), aad, len(aad))
451457
if ret < 0:
452458
raise WolfCryptError("Encryption error (%d)" % ret)
453-
return bytes(self._buf)
459+
return bytes(buf)
454460

455461
def decrypt(self, data):
456462
"""
@@ -463,11 +469,11 @@ def decrypt(self, data):
463469
aad = self._aad
464470
elif self._mode == _ENCRYPTION:
465471
raise WolfCryptError("Class instance already in use for encryption")
466-
self._buf = _ffi.new("byte[%d]" % (len(data)))
467-
ret = _lib.wc_AesGcmDecryptUpdate(self._native_object, self._buf, data, len(data), aad, len(aad))
472+
buf = _ffi.new("byte[%d]" % (len(data)))
473+
ret = _lib.wc_AesGcmDecryptUpdate(self._native_object, buf, data, len(data), aad, len(aad))
468474
if ret < 0:
469475
raise WolfCryptError("Decryption error (%d)" % ret)
470-
return bytes(self._buf)
476+
return bytes(buf)
471477

472478
def final(self, authTag=None):
473479
"""
@@ -505,7 +511,9 @@ class ChaCha(_Cipher):
505511
_IV_nonce = b""
506512
_IV_counter = 0
507513

508-
def __init__(self, key="", size=32):
514+
def __init__(self, key="", size=32): # pylint: disable=unused-argument
515+
# size is kept for backwards compatibility; key length is now
516+
# derived from the actual key and validated against _key_sizes.
509517
self._native_object = _ffi.new(self._native_type)
510518
self._enc = None
511519
self._dec = None
@@ -552,7 +560,9 @@ def set_iv(self, nonce, counter = 0):
552560
raise ValueError("nonce must be %d bytes, got %d" %
553561
(self._NONCE_SIZE, len(self._IV_nonce)))
554562
self._IV_counter = counter
555-
self._set_key(0)
563+
ret = self._set_key(0)
564+
if ret < 0:
565+
raise WolfCryptError("ChaCha set_iv error (%d)" % ret)
556566

557567
if _lib.CHACHA20_POLY1305_ENABLED:
558568
class ChaCha20Poly1305(object):
@@ -864,6 +874,9 @@ def make_key(cls, size, rng=None, hash_type=None):
864874
if rsa.output_size <= 0: # pragma: no cover
865875
raise WolfCryptError("Invalid key size error (%d)" % ret)
866876

877+
# Retain RNG reference defensively.
878+
rsa._rng = rng
879+
867880
return rsa
868881

869882
def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-called
@@ -1231,7 +1244,11 @@ def make_key(cls, size, rng=None):
12311244
ret = _lib.wc_ecc_set_rng(ecc.native_object, rng.native_object)
12321245
if ret < 0:
12331246
raise WolfCryptError("Error setting ECC RNG (%d)" % ret)
1234-
ecc._rng = rng
1247+
1248+
# Retain the RNG so it outlives the ECC key. Even outside the
1249+
# timing-resistance path, wolfSSL internals may retain a pointer
1250+
# to the RNG; keeping the reference avoids any UAF risk.
1251+
ecc._rng = rng
12351252

12361253
return ecc
12371254

@@ -1504,6 +1521,10 @@ def make_key(cls, size, rng=None):
15041521
if ret < 0:
15051522
raise WolfCryptError("Key generation error (%d)" % ret)
15061523

1524+
# Retain RNG reference defensively; wolfSSL may retain a pointer
1525+
# internally on some builds.
1526+
ed25519._rng = rng
1527+
15071528
return ed25519
15081529

15091530
def decode_key(self, key, pub = None):
@@ -1706,6 +1727,10 @@ def make_key(cls, size, rng=None):
17061727
if ret < 0:
17071728
raise WolfCryptError("Key generation error (%d)" % ret)
17081729

1730+
# Retain RNG reference defensively; wolfSSL may retain a pointer
1731+
# internally on some builds.
1732+
ed448._rng = rng
1733+
17091734
return ed448
17101735

17111736
def decode_key(self, key, pub = None):
@@ -1979,6 +2004,9 @@ def make_key(cls, mlkem_type, rng=None):
19792004
if ret < 0: # pragma: no cover
19802005
raise WolfCryptError("wc_KyberKey_MakeKey() error (%d)" % ret)
19812006

2007+
# Retain RNG reference defensively.
2008+
mlkem_priv._rng = rng
2009+
19822010
return mlkem_priv
19832011

19842012
@classmethod
@@ -2226,6 +2254,9 @@ def make_key(cls, mldsa_type, rng=None):
22262254
if ret < 0: # pragma: no cover
22272255
raise WolfCryptError("wc_dilithium_make_key() error (%d)" % ret)
22282256

2257+
# Retain RNG reference defensively.
2258+
mldsa_priv._rng = rng
2259+
22292260
return mldsa_priv
22302261

22312262
@property

0 commit comments

Comments
 (0)