diff --git a/src/create_discriminant.h b/src/create_discriminant.h index 81408323..9a294bbd 100644 --- a/src/create_discriminant.h +++ b/src/create_discriminant.h @@ -4,6 +4,37 @@ #include "proof_common.h" integer CreateDiscriminant(std::vector& seed, int length = 1024) { + // INPUT VALIDATION - Fix for issue #282 + + // Check 1: Validate discriminant_size_bits is positive + if (length <= 0) { + throw std::invalid_argument( + "discriminant_size_bits must be positive (got " + + std::to_string(length) + ")" + ); + } + + // Check 2: Validate upper bound (optional but recommended) + const int MAX_DISCRIMINANT_SIZE_BITS = 16384; + if (length > MAX_DISCRIMINANT_SIZE_BITS) { + throw std::invalid_argument( + "discriminant_size_bits exceeds maximum allowed value" + ); + } + + // Check 3: Validate that length is a multiple of 8 (required by HashPrime) + if (length % 8 != 0) { + throw std::invalid_argument( + "discriminant_size_bits must be a multiple of 8 (got " + + std::to_string(length) + ")" + ); + } + + // Check 4: Validate seed is not empty + if (seed.empty()) { + throw std::invalid_argument("seed cannot be empty"); + } + return HashPrime(seed, length, {0, 1, 2, length - 1}) * integer(-1); } diff --git a/tests/test_discriminant_validation.py b/tests/test_discriminant_validation.py new file mode 100644 index 00000000..5e7015c5 --- /dev/null +++ b/tests/test_discriminant_validation.py @@ -0,0 +1,34 @@ +import secrets +import pytest + +from chiavdf import create_discriminant + + +def test_discriminant_size_bits_negative(): + """Test that discriminant_size_bits of -1 fails""" + discriminant_challenge = secrets.token_bytes(10) + with pytest.raises(ValueError, match="discriminant_size_bits must be positive"): + create_discriminant(discriminant_challenge, -1) + + +def test_discriminant_size_bits_zero(): + """Test that discriminant_size_bits of 0 fails""" + discriminant_challenge = secrets.token_bytes(10) + with pytest.raises(ValueError, match="discriminant_size_bits must be positive"): + create_discriminant(discriminant_challenge, 0) + + +def test_discriminant_size_bits_too_large(): + """Test that discriminant_size_bits of 16400 fails (exceeds max of 16384)""" + discriminant_challenge = secrets.token_bytes(10) + with pytest.raises(ValueError, match="discriminant_size_bits exceeds maximum allowed value"): + create_discriminant(discriminant_challenge, 16400) + + +def test_discriminant_size_bits_valid(): + """Test that discriminant_size_bits of 1024 succeeds""" + discriminant_challenge = secrets.token_bytes(10) + discriminant = create_discriminant(discriminant_challenge, 1024) + # If we get here without an exception, the test passes + assert discriminant is not None + assert len(discriminant) > 0 # Should return a string representation of the discriminant