Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
#include "barretenberg/dsl/acir_format/witness_constant.hpp"
#include "barretenberg/stdlib/primitives/curves/secp256k1.hpp"
#include "barretenberg/stdlib/primitives/curves/secp256r1.hpp"
#include "barretenberg/ultra_honk/ultra_prover.hpp"
#include "barretenberg/ultra_honk/ultra_verifier.hpp"

#include <algorithm>
#include <gtest/gtest.h>
#include <memory>
#include <vector>

using namespace bb;
Expand All @@ -31,7 +34,7 @@ template <class Curve> class EcdsaTestingFunctions {
ZeroR, // Set R=0 (tests ECDSA validation)
ZeroS, // Set S=0 (tests ECDSA validation)
HighS, // Set S=high (tests malleability protection)
P, // Invalidate public key
P, // Make public key fail the curve equation
Result // Invalid signature with claimed valid result
};

Expand All @@ -43,7 +46,8 @@ template <class Curve> class EcdsaTestingFunctions {

static std::vector<std::string> get_labels()
{
return { "None", "Hash is not a byte array", "Zero R", "Zero S", "High S", "Public key", "Result" };
return { "None", "Hash is not a byte array", "Zero R", "Zero S",
"High S", "Public key not on curve", "Result" };
}
};

Expand All @@ -53,6 +57,20 @@ template <class Curve> class EcdsaTestingFunctions {

static ProgramMetadata generate_metadata() { return ProgramMetadata{}; }

static std::pair<AcirConstraint, WitnessVector> generate_invalid_verification_result_constraints(
const InvalidWitness::Target& invalid_witness_target)
{
AcirConstraint ecdsa_constraint;
WitnessVector witness_values;
generate_constraints(ecdsa_constraint, witness_values);

auto [invalid_constraint, invalid_witness_values] =
invalidate_witness(ecdsa_constraint, witness_values, invalid_witness_target);

invalid_witness_values[invalid_constraint.result] = bb::fr(0);
return { invalid_constraint, invalid_witness_values };
}

static std::pair<AcirConstraint, WitnessVector> invalidate_witness(
AcirConstraint ecdsa_constraints,
WitnessVector witness_values,
Expand Down Expand Up @@ -94,7 +112,7 @@ template <class Curve> class EcdsaTestingFunctions {
};
break;
case InvalidWitness::Target::P:
// Invalidate public key
// Invalidate public key so signature verification returns false.
witness_values[ecdsa_constraints.pub_x_indices[0]] += bb::fr(1);
break;
case InvalidWitness::Target::Result:
Expand Down Expand Up @@ -169,6 +187,24 @@ template <class Curve> class EcdsaTestingFunctions {
}
};

template <typename Flavor> bool construct_and_verify_honk_proof(typename Flavor::CircuitBuilder& builder)
{
using Prover = UltraProver_<Flavor>;
using Verifier = UltraVerifier_<Flavor, DefaultIO>;
using ProverInstance = ProverInstance_<Flavor>;
using VerificationKey = typename Flavor::VerificationKey;

auto prover_instance = std::make_shared<ProverInstance>(builder);
auto verification_key = std::make_shared<VerificationKey>(prover_instance->get_precomputed());
auto vk_and_hash = std::make_shared<typename Flavor::VKAndHash>(verification_key);

Prover prover(prover_instance, verification_key);
auto proof = prover.construct_proof();

Verifier verifier(vk_and_hash);
return verifier.verify_proof(proof).result;
}

template <class Curve>
class EcdsaConstraintsTest : public ::testing::Test, public TestClassWithPredicate<EcdsaTestingFunctions<Curve>> {
protected:
Expand Down Expand Up @@ -220,3 +256,34 @@ TYPED_TEST(EcdsaConstraintsTest, InvalidWitnesses)
BB_DISABLE_ASSERTS();
[[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
}

TYPED_TEST(EcdsaConstraintsTest, InvalidVerificationInputsReturnFalseAndProve)
{
BB_DISABLE_ASSERTS();
using Builder = typename TypeParam::Builder;
using Flavor = std::conditional_t<std::is_same_v<Builder, UltraCircuitBuilder>, UltraFlavor, MegaFlavor>;
using InvalidWitnessTarget = typename TestFixture::InvalidWitnessTarget;

const std::vector<InvalidWitnessTarget> invalid_targets = {
InvalidWitnessTarget::ZeroR,
InvalidWitnessTarget::ZeroS,
InvalidWitnessTarget::P,
};
const std::vector<std::string> target_labels = { "zero r", "zero s", "public key not on curve" };

for (auto [invalid_target, target_label] : zip_view(invalid_targets, target_labels)) {
SCOPED_TRACE(target_label);

auto [constraint, witness_values] =
TestFixture::Base::generate_invalid_verification_result_constraints(invalid_target);
ASSERT_EQ(witness_values[constraint.result], bb::fr(0));

AcirFormat constraint_system = constraint_to_acir_format(constraint);
AcirProgram program{ constraint_system, witness_values };
auto builder = create_circuit<Builder>(program, TestFixture::Base::generate_metadata());

EXPECT_TRUE(CircuitChecker::check(builder));
EXPECT_FALSE(builder.failed()) << builder.err();
EXPECT_TRUE(construct_and_verify_honk_proof<Flavor>(builder));
}
}
Loading