Skip to content

Commit b070887

Browse files
authored
Merge pull request #93 from roberthdevries/rsa-public-add-rng-param
Make the random generator of _Rsa and RsaPublic configurable.
2 parents c17f7f0 + 377542f commit b070887

2 files changed

Lines changed: 70 additions & 11 deletions

File tree

tests/test_ciphers.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from wolfcrypt.ciphers import MODE_CTR, MODE_ECB, MODE_CBC, WolfCryptError
2828
from wolfcrypt.random import Random
2929
from wolfcrypt.utils import t2b, h2b
30+
from wolfcrypt.random import Random
3031
import os
3132

3233
certs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "certs")
@@ -326,10 +327,18 @@ def test_chacha_enc_dec(chacha_obj):
326327
assert plaintext == dec
327328

328329
if _lib.RSA_ENABLED:
330+
@pytest.fixture
331+
def rng():
332+
return Random()
333+
329334
@pytest.fixture
330335
def rsa_private(vectors):
331336
return RsaPrivate(vectors[RsaPrivate].key)
332337

338+
@pytest.fixture
339+
def rsa_private_rng(vectors, rng):
340+
return RsaPrivate(vectors[RsaPrivate].key, rng=rng)
341+
333342
@pytest.fixture
334343
def rsa_private_oaep(vectors):
335344
return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA)
@@ -346,6 +355,10 @@ def rsa_private_pkcs8(vectors):
346355
def rsa_public(vectors):
347356
return RsaPublic(vectors[RsaPublic].key)
348357

358+
@pytest.fixture
359+
def rsa_public_rng(vectors, rng):
360+
return RsaPublic(vectors[RsaPublic].key, rng=rng)
361+
349362
@pytest.fixture
350363
def rsa_public_oaep(vectors):
351364
return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA)
@@ -366,6 +379,17 @@ def rsa_public_pem(vectors):
366379
pem = f.read()
367380
return RsaPublic.from_pem(pem)
368381

382+
@pytest.fixture
383+
def rsa_private_pem_rng(vectors, rng):
384+
with open(vectors[RsaPrivate].pem, "rb") as f:
385+
pem = f.read()
386+
return RsaPrivate.from_pem(pem, rng=rng)
387+
388+
@pytest.fixture
389+
def rsa_public_pem_rng(vectors, rng):
390+
with open(vectors[RsaPublic].pem, "rb") as f:
391+
pem = f.read()
392+
return RsaPublic.from_pem(pem, rng=rng)
369393

370394
def test_new_rsa_raises(vectors):
371395
with pytest.raises(WolfCryptError):
@@ -395,6 +419,22 @@ def test_rsa_encrypt_decrypt(rsa_private, rsa_public):
395419
assert 1024 / 8 == len(ciphertext) == rsa_private.output_size
396420
assert plaintext == rsa_private.decrypt(ciphertext)
397421

422+
def test_rsa_encrypt_decrypt_rng(rsa_private_rng, rsa_public_rng):
423+
plaintext = t2b("Everyone gets Friday off.")
424+
425+
# normal usage, encrypt with public, decrypt with private
426+
ciphertext = rsa_public_rng.encrypt(plaintext)
427+
428+
assert 1024 / 8 == len(ciphertext) == rsa_public_rng.output_size
429+
assert plaintext == rsa_private_rng.decrypt(ciphertext)
430+
431+
# private object holds both private and public info, so it can also encrypt
432+
# using the known public key.
433+
ciphertext = rsa_private_rng.encrypt(plaintext)
434+
435+
assert 1024 / 8 == len(ciphertext) == rsa_private_rng.output_size
436+
assert plaintext == rsa_private_rng.decrypt(ciphertext)
437+
398438
def test_rsa_encrypt_decrypt_pad_oaep(rsa_private_oaep, rsa_public_oaep):
399439
plaintext = t2b("Everyone gets Friday off.")
400440

@@ -478,6 +518,22 @@ def test_rsa_sign_verify_pem(rsa_private_pem, rsa_public_pem):
478518
assert 256 == len(signature) == rsa_private_pem.output_size
479519
assert plaintext == rsa_private_pem.verify(signature)
480520

521+
def test_rsa_sign_verify_pem_rng(rsa_private_pem_rng, rsa_public_pem_rng):
522+
plaintext = t2b("Everyone gets Friday off.")
523+
524+
# normal usage, sign with private, verify with public
525+
signature = rsa_private_pem_rng.sign(plaintext)
526+
527+
assert 256 == len(signature) == rsa_private_pem_rng.output_size
528+
assert plaintext == rsa_public_pem_rng.verify(signature)
529+
530+
# private object holds both private and public info, so it can also verify
531+
# using the known public key.
532+
signature = rsa_private_pem_rng.sign(plaintext)
533+
534+
assert 256 == len(signature) == rsa_private_pem_rng.output_size
535+
assert plaintext == rsa_private_pem_rng.verify(signature)
536+
481537
def test_rsa_pkcs8_sign_verify(rsa_private_pkcs8, rsa_public):
482538
plaintext = t2b("Everyone gets Friday off.")
483539

wolfcrypt/ciphers.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -684,13 +684,16 @@ class _Rsa: # pylint: disable=too-few-public-methods
684684
_mgf = None
685685
_hash_type = None
686686

687-
def __init__(self):
687+
def __init__(self, rng=None):
688+
if rng is None:
689+
rng = Random()
690+
688691
self.native_object = _ffi.new("RsaKey *")
689692
ret = _lib.wc_InitRsaKey(self.native_object, _ffi.NULL)
690693
if ret < 0: # pragma: no cover
691694
raise WolfCryptError("Invalid key error (%d)" % ret)
692695

693-
self._random = Random()
696+
self._random = rng
694697
if _lib.RSA_BLINDING_ENABLED:
695698
ret = _lib.wc_RsaSetRNG(self.native_object,
696699
self._random.native_object)
@@ -724,13 +727,13 @@ def _get_mgf(self):
724727

725728

726729
class RsaPublic(_Rsa):
727-
def __init__(self, key=None, hash_type=None):
730+
def __init__(self, key=None, hash_type=None, rng=None):
731+
super().__init__(rng)
732+
728733
if key is not None:
729734
key = t2b(key)
730735
self._hash_type = hash_type
731736

732-
_Rsa.__init__(self)
733-
734737
idx = _ffi.new("word32*")
735738
idx[0] = 0
736739

@@ -747,9 +750,9 @@ def __init__(self, key=None, hash_type=None):
747750

748751
if _lib.ASN_ENABLED:
749752
@classmethod
750-
def from_pem(cls, file, hash_type=None):
753+
def from_pem(cls, file, hash_type=None, rng=None):
751754
der = pem_to_der(file, _lib.PUBLICKEY_TYPE)
752-
return cls(key=der, hash_type=hash_type)
755+
return cls(key=der, hash_type=hash_type, rng=rng)
753756

754757
def encrypt(self, plaintext):
755758
"""
@@ -883,9 +886,9 @@ def make_key(cls, size, rng=None, hash_type=None):
883886

884887
return rsa
885888

886-
def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-called
889+
def __init__(self, key=None, hash_type=None, rng=None): # pylint: disable=super-init-not-called
887890

888-
_Rsa.__init__(self) # pylint: disable=non-parent-init-called
891+
_Rsa.__init__(self, rng) # pylint: disable=non-parent-init-called
889892
self._hash_type = hash_type
890893
idx = _ffi.new("word32*")
891894
idx[0] = 0
@@ -913,9 +916,9 @@ def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-
913916

914917
if _lib.ASN_ENABLED:
915918
@classmethod
916-
def from_pem(cls, file, hash_type=None):
919+
def from_pem(cls, file, hash_type=None, rng=None):
917920
der = pem_to_der(file, _lib.PRIVATEKEY_TYPE)
918-
return cls(key=der, hash_type=hash_type)
921+
return cls(key=der, hash_type=hash_type, rng=rng)
919922

920923
if _lib.KEYGEN_ENABLED:
921924
def encode_key(self):

0 commit comments

Comments
 (0)