@@ -4,10 +4,20 @@ use halo2_proofs::{
44 dev:: MockProver ,
55 pasta:: Fp ,
66} ;
7+ use ff:: PrimeField ;
8+ use serde:: { Serialize , Deserialize } ;
79
810mod circuit;
911use circuit:: LoRACircuit ;
1012
13+ #[ derive( Serialize , Deserialize ) ]
14+ struct ProofData {
15+ input : Vec < f64 > ,
16+ weight_a : Vec < f64 > ,
17+ weight_b : Vec < f64 > ,
18+ expected_output : u64 ,
19+ }
20+
1121/// Generate a zero-knowledge proof for LoRA matrix multiplication
1222#[ pyfunction]
1323fn generate_proof (
@@ -41,43 +51,90 @@ fn generate_proof(
4151 } ;
4252 let expected_instance_output = i_val * wa_val * wb_val;
4353
44- // For now, we'll use MockProver for testing
45- // In production, this would use actual Halo2 proving system
54+ // Use MockProver to validate the circuit
4655 let k = 4 ; // Small value for testing
4756 let prover = MockProver :: run ( k, & circuit, vec ! [ vec![ expected_instance_output] ] ) . unwrap ( ) ;
4857
4958 // Verify the circuit constraints
5059 prover. verify ( ) . unwrap ( ) ;
5160
52- // Serialize the proof - in production this would be actual proof bytes
53- let dummy_proof = vec ! [ 0u8 ; 32 ] ; // TODO: Replace with actual proof serialization
54- Ok ( PyBytes :: new ( py, & dummy_proof) . into ( ) )
61+ // Create proof data containing the private inputs and expected output
62+ let proof_data = ProofData {
63+ input : input. clone ( ) ,
64+ weight_a : weight_a. clone ( ) ,
65+ weight_b : weight_b. clone ( ) ,
66+ expected_output : {
67+ // Convert Fp to u64 by extracting the underlying value
68+ let bytes = expected_instance_output. to_repr ( ) ;
69+ u64:: from_le_bytes ( [
70+ bytes[ 0 ] , bytes[ 1 ] , bytes[ 2 ] , bytes[ 3 ] ,
71+ bytes[ 4 ] , bytes[ 5 ] , bytes[ 6 ] , bytes[ 7 ] ,
72+ ] )
73+ } ,
74+ } ;
75+
76+ // Serialize the proof data to bytes
77+ let serialized = bincode:: serialize ( & proof_data) . map_err ( |e| {
78+ PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > ( format ! ( "Serialization failed: {}" , e) )
79+ } ) ?;
80+
81+ Ok ( PyBytes :: new ( py, & serialized) . into ( ) )
5582}
5683
5784/// Verify a zero-knowledge proof
5885#[ pyfunction]
59- fn verify_proof ( _proof : & [ u8 ] , public_inputs : Vec < f64 > ) -> PyResult < bool > {
60- // Create a circuit with the public inputs
86+ fn verify_proof ( proof : & [ u8 ] , public_inputs : Vec < f64 > ) -> PyResult < bool > {
87+ // Deserialize the proof data
88+ let proof_data: ProofData = bincode:: deserialize ( proof) . map_err ( |_| {
89+ PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( "Invalid proof format" )
90+ } ) ?;
91+
92+ // Create a circuit with the data from the proof
6193 let circuit = LoRACircuit {
62- input : public_inputs . clone ( ) ,
63- weight_a : vec ! [ ] , // These will be private inputs
64- weight_b : vec ! [ ] , // These will be private inputs
94+ input : proof_data . input . clone ( ) ,
95+ weight_a : proof_data . weight_a . clone ( ) ,
96+ weight_b : proof_data . weight_b . clone ( ) ,
6597 } ;
6698
67- // For now, we'll use MockProver for verification
68- // In production, this would use actual Halo2 verification
69- let k = 4 ; // Same k as in proof generation
99+ // Calculate expected output from public inputs or use proof data
100+ let expected_public_output = if !public_inputs. is_empty ( ) {
101+ // Use provided public inputs
102+ Fp :: from ( public_inputs[ 0 ] . abs ( ) as u64 )
103+ } else {
104+ // Use expected output from proof data when no public inputs provided
105+ Fp :: from ( proof_data. expected_output )
106+ } ;
70107
71- // Calculate expected output
72- let expected_output = if !public_inputs. is_empty ( ) {
73- vec ! [ Fp :: from( public_inputs[ 0 ] . abs( ) as u64 ) ]
108+ // Calculate the actual output based on the circuit computation
109+ let i_val = if !proof_data. input . is_empty ( ) {
110+ Fp :: from ( proof_data. input [ 0 ] . abs ( ) as u64 )
111+ } else {
112+ Fp :: zero ( )
113+ } ;
114+ let wa_val = if !proof_data. weight_a . is_empty ( ) {
115+ Fp :: from ( proof_data. weight_a [ 0 ] . abs ( ) as u64 )
116+ } else {
117+ Fp :: one ( )
118+ } ;
119+ let wb_val = if !proof_data. weight_b . is_empty ( ) {
120+ Fp :: from ( proof_data. weight_b [ 0 ] . abs ( ) as u64 )
74121 } else {
75- vec ! [ Fp :: zero ( ) ]
122+ Fp :: one ( )
76123 } ;
124+ let computed_output = i_val * wa_val * wb_val;
77125
78- let prover = MockProver :: run ( k, & circuit, vec ! [ expected_output] ) . unwrap ( ) ;
126+ // Verify that the computed output matches the expected public input
127+ if computed_output != expected_public_output {
128+ return Ok ( false ) ;
129+ }
130+
131+ // Verify the circuit constraints using MockProver
132+ let k = 4 ;
133+ let prover = MockProver :: run ( k, & circuit, vec ! [ vec![ computed_output] ] ) . map_err ( |_| {
134+ PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > ( "MockProver setup failed" )
135+ } ) ?;
79136
80- // Verify the circuit constraints
137+ // Return true if verification passes, false otherwise
81138 match prover. verify ( ) {
82139 Ok ( _) => Ok ( true ) ,
83140 Err ( _) => Ok ( false ) ,
0 commit comments