Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.

Commit 7172dbc

Browse files
authored
Merge pull request #4 from Freedom-Club-FC/refactor/add-tests
Add initial quantum and traditional crypto tests
2 parents 4f6aa3f + 454fa76 commit 7172dbc

File tree

2 files changed

+163
-0
lines changed

2 files changed

+163
-0
lines changed

tests/test_crypto.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# tests/test_crypto.py
2+
"""
3+
Tests for ML-KEM-1024 (Kyber) and ML-DSA-87 (Dilithium5).
4+
Covers:
5+
- Key generation conformance to NIST spec
6+
- OTP encryption using Kyber key exchange
7+
- Hash chain tamper detection
8+
"""
9+
10+
import pytest
11+
from core.crypto import (
12+
generate_kem_keys,
13+
generate_sign_keys,
14+
generate_kyber_shared_secrets,
15+
decrypt_kyber_shared_secrets,
16+
otp_encrypt_with_padding,
17+
otp_decrypt_with_padding
18+
)
19+
from core.constants import (
20+
OTP_PADDING_LIMIT,
21+
OTP_PADDING_LENGTH
22+
)
23+
from core.trad_crypto import sha3_512
24+
25+
# NIST-specified key sizes (bytes)
26+
ML_KEM_1024_SK_LEN = 3168
27+
ML_KEM_1024_PK_LEN = 1568
28+
ML_DSA_87_SK_LEN = 4864
29+
ML_DSA_87_PK_LEN = 2592
30+
HASH_SIZE = 64 # SHA3-512 output size in bytes
31+
32+
33+
def test_mlkem_keygen_basic():
34+
"""Validate ML-KEM-1024 key generation: uniqueness, type, and length."""
35+
seen_private_keys = set()
36+
seen_public_keys = set()
37+
38+
for _ in range(100):
39+
private_key, public_key = generate_kem_keys(algorithm="Kyber1024")
40+
41+
assert private_key not in seen_private_keys, "Duplicate private key detected"
42+
assert public_key not in seen_public_keys, "Duplicate public key detected"
43+
44+
assert private_key != public_key, "Private and public keys must differ"
45+
assert isinstance(private_key, bytes) and isinstance(public_key, bytes), "Keys must be bytes"
46+
assert len(private_key) == ML_KEM_1024_SK_LEN, "Private key length mismatch with spec"
47+
assert len(public_key) == ML_KEM_1024_PK_LEN, "Public key length mismatch with spec"
48+
49+
seen_private_keys.add(private_key)
50+
seen_public_keys.add(public_key)
51+
52+
53+
def test_mldsa_keygen_basic():
54+
"""Validate ML-DSA-87 key generation: uniqueness, type, and length."""
55+
seen_private_keys = set()
56+
seen_public_keys = set()
57+
58+
for _ in range(100):
59+
private_key, public_key = generate_sign_keys(algorithm="Dilithium5")
60+
61+
assert private_key not in seen_private_keys, "Duplicate private key detected"
62+
assert public_key not in seen_public_keys, "Duplicate public key detected"
63+
64+
assert private_key != public_key, "Private and public keys must differ"
65+
assert isinstance(private_key, bytes) and isinstance(public_key, bytes), "Keys must be bytes"
66+
assert len(private_key) == ML_DSA_87_SK_LEN, "Private key length mismatch with spec"
67+
assert len(public_key) == ML_DSA_87_PK_LEN, "Public key length mismatch with spec"
68+
69+
seen_private_keys.add(private_key)
70+
seen_public_keys.add(public_key)
71+
72+
73+
def test_kem_otp_encryption():
74+
"""Full Kyber OTP exchange and tamper detection test."""
75+
# Alice creates ephemeral ML-KEM-1024 keypair for PFS
76+
alice_private_key, alice_public_key = generate_kem_keys()
77+
78+
# Bob creates his own ephemeral keypair
79+
bob_private_key, bob_public_key = generate_kem_keys()
80+
81+
# Bob derives shared pads from Alice's public key
82+
ciphertext, bob_pads = generate_kyber_shared_secrets(alice_public_key)
83+
assert ciphertext != bob_pads, "Ciphertext equals pads (should differ)"
84+
85+
# First 64 bytes are hash chain seed
86+
bob_hash_chain_seed = bob_pads[:HASH_SIZE]
87+
88+
# Alice decrypts ciphertext to recover shared pads
89+
plaintext = decrypt_kyber_shared_secrets(ciphertext, alice_private_key)
90+
assert plaintext == bob_pads, "Pads mismatch after decryption"
91+
92+
# Bob encrypts a message using OTP with hash chain
93+
message = "Hello, World!"
94+
message_encoded = message.encode("utf-8")
95+
bob_next_hash_chain = sha3_512(bob_hash_chain_seed + message_encoded)
96+
message_encoded = bob_next_hash_chain + message_encoded
97+
98+
pad_len = max(0, OTP_PADDING_LIMIT - OTP_PADDING_LENGTH - len(message_encoded))
99+
otp_pad = bob_pads[:len(message_encoded) + OTP_PADDING_LENGTH + pad_len]
100+
encrypted = otp_encrypt_with_padding(message_encoded, otp_pad, padding_limit=pad_len)
101+
102+
assert encrypted != message_encoded, "Ciphertext equals plaintext"
103+
assert len(encrypted) == len(otp_pad), "Ciphertext length mismatch"
104+
105+
# Alice decrypts and validates hash chain
106+
decrypted = otp_decrypt_with_padding(encrypted, plaintext[:len(encrypted)])
107+
recv_hash = decrypted[:HASH_SIZE]
108+
recv_plaintext = decrypted[HASH_SIZE:]
109+
assert recv_plaintext.decode() == message, "Decrypted message mismatch"
110+
111+
calc_next_hash = sha3_512(bob_hash_chain_seed + recv_plaintext)
112+
assert calc_next_hash == recv_hash, "Hash chain verification failed"
113+
114+
# Tampering test: flip a byte
115+
tampered_message = bytearray(encrypted)
116+
tampered_message[65] ^= 0xFF
117+
118+
tampered_decrypted = otp_decrypt_with_padding(bytes(tampered_message), plaintext[:len(encrypted)])
119+
tampered_hash = tampered_decrypted[:HASH_SIZE]
120+
tampered_plaintext = tampered_decrypted[HASH_SIZE:]
121+
122+
calc_tampered_hash = sha3_512(bob_hash_chain_seed + tampered_plaintext)
123+
assert calc_tampered_hash != tampered_hash, "Tampering not detected"

tests/test_trad_crypto.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# tests/test_trad_crypto.py
2+
"""
3+
Tests for AES-256 GCM encryption/decryption and Argon2id key derivation.
4+
Focus: Correctness of encryption/decryption flow and tamper detection.
5+
"""
6+
7+
import pytest
8+
from core.trad_crypto import (
9+
encrypt_aes_gcm,
10+
decrypt_aes_gcm,
11+
derive_key_argon2id
12+
)
13+
14+
15+
def test_aes_encrypt_decrypt():
16+
# Test input data
17+
data = b"Hello, World!"
18+
password = b"Password123"
19+
20+
# Derive AES-256 key using Argon2id
21+
key, salt = derive_key_argon2id(password, output_length=32)
22+
assert key != salt, "Derived key should not equal derived salt"
23+
assert key != password, "Derived key should not match plaintext password"
24+
25+
# Encrypt plaintext using AES-GCM
26+
nonce, ciphertext = encrypt_aes_gcm(key, data)
27+
assert nonce != ciphertext, "Nonce and ciphertext should not be equal"
28+
assert ciphertext != data, "Ciphertext should differ from plaintext"
29+
30+
# Decrypt ciphertext and verify correctness
31+
plaintext = decrypt_aes_gcm(key, nonce, ciphertext)
32+
assert plaintext == data, "Decrypted plaintext does not match original"
33+
34+
# Tampering test: Modify ciphertext and expect decryption failure
35+
tampered_ciphertext = bytearray(ciphertext)
36+
tampered_ciphertext[-1] ^= 0xFF # Flip last byte to corrupt data
37+
38+
with pytest.raises(Exception):
39+
decrypt_aes_gcm(key, nonce, bytes(tampered_ciphertext))
40+

0 commit comments

Comments
 (0)