@@ -4,12 +4,74 @@ use halo2_proofs::{
44 dev:: MockProver ,
55 pasta:: Fp ,
66} ;
7- use ff:: PrimeField ;
7+ use ff:: { PrimeField , Field } ;
88use serde:: { Serialize , Deserialize } ;
99
1010mod circuit;
1111use circuit:: LoRACircuit ;
1212
13+ // Constants for quantization
14+ const SCALE_FACTOR : u64 = 10_000 ; // 10^4 for 4 decimal places
15+ const SCALE_FACTOR_F64 : f64 = SCALE_FACTOR as f64 ;
16+
17+ /// Convert a floating-point value to a field element using fixed-point arithmetic
18+ fn quantize_to_field < F : PrimeField > ( value : f64 ) -> F {
19+ // Handle special cases
20+ if value == 0.0 {
21+ return F :: zero ( ) ;
22+ }
23+ if value == 1.0 {
24+ return F :: one ( ) ;
25+ }
26+ if value == -1.0 {
27+ return -F :: one ( ) ;
28+ }
29+
30+ // Take absolute value and scale
31+ let abs_val = value. abs ( ) ;
32+ let scaled = ( abs_val * SCALE_FACTOR_F64 ) . round ( ) as u64 ;
33+
34+ // Convert to field element
35+ let field_val = F :: from ( scaled) ;
36+
37+ // Apply sign
38+ if value < 0.0 {
39+ -field_val
40+ } else {
41+ field_val
42+ }
43+ }
44+
45+ /// Convert a field element back to a floating-point value
46+ fn dequantize_from_field < F : PrimeField > ( value : F ) -> f64 {
47+ // Handle special cases
48+ if value == F :: zero ( ) {
49+ return 0.0 ;
50+ }
51+ if value == F :: one ( ) {
52+ return 1.0 ;
53+ }
54+ if value == -F :: one ( ) {
55+ return -1.0 ;
56+ }
57+
58+ // Convert to bytes and then to u64
59+ let bytes = value. to_repr ( ) ;
60+ let mut u64_bytes = [ 0u8 ; 8 ] ;
61+ u64_bytes. copy_from_slice ( & bytes[ 0 ..8 ] ) ;
62+ let scaled = u64:: from_le_bytes ( u64_bytes) ;
63+
64+ // Dequantize by dividing by scale factor
65+ let dequantized = scaled as f64 / SCALE_FACTOR_F64 ;
66+
67+ // Handle negative values
68+ if value < F :: zero ( ) {
69+ -dequantized
70+ } else {
71+ dequantized
72+ }
73+ }
74+
1375#[ derive( Serialize , Deserialize ) ]
1476struct ProofData {
1577 input : Vec < f64 > ,
@@ -35,17 +97,17 @@ fn generate_proof(
3597
3698 // Calculate expected output for MockProver based on circuit's logic
3799 let i_val = if !input. is_empty ( ) {
38- Fp :: from ( input[ 0 ] . abs ( ) as u64 )
100+ quantize_to_field ( input[ 0 ] )
39101 } else {
40102 Fp :: zero ( )
41103 } ;
42104 let wa_val = if !weight_a. is_empty ( ) {
43- Fp :: from ( weight_a[ 0 ] . abs ( ) as u64 )
105+ quantize_to_field ( weight_a[ 0 ] )
44106 } else {
45107 Fp :: one ( )
46108 } ;
47109 let wb_val = if !weight_b. is_empty ( ) {
48- Fp :: from ( weight_b[ 0 ] . abs ( ) as u64 )
110+ quantize_to_field ( weight_b[ 0 ] )
49111 } else {
50112 Fp :: one ( )
51113 } ;
@@ -99,25 +161,25 @@ fn verify_proof(proof: &[u8], public_inputs: Vec<f64>) -> PyResult<bool> {
99161 // Calculate expected output from public inputs or use proof data
100162 let expected_public_output = if !public_inputs. is_empty ( ) {
101163 // Use provided public inputs
102- Fp :: from ( public_inputs[ 0 ] . abs ( ) as u64 )
164+ quantize_to_field ( public_inputs[ 0 ] )
103165 } else {
104166 // Use expected output from proof data when no public inputs provided
105167 Fp :: from ( proof_data. expected_output )
106168 } ;
107169
108170 // Calculate the actual output based on the circuit computation
109171 let i_val = if !proof_data. input . is_empty ( ) {
110- Fp :: from ( proof_data. input [ 0 ] . abs ( ) as u64 )
172+ quantize_to_field ( proof_data. input [ 0 ] )
111173 } else {
112174 Fp :: zero ( )
113175 } ;
114176 let wa_val = if !proof_data. weight_a . is_empty ( ) {
115- Fp :: from ( proof_data. weight_a [ 0 ] . abs ( ) as u64 )
177+ quantize_to_field ( proof_data. weight_a [ 0 ] )
116178 } else {
117179 Fp :: one ( )
118180 } ;
119181 let wb_val = if !proof_data. weight_b . is_empty ( ) {
120- Fp :: from ( proof_data. weight_b [ 0 ] . abs ( ) as u64 )
182+ quantize_to_field ( proof_data. weight_b [ 0 ] )
121183 } else {
122184 Fp :: one ( )
123185 } ;
@@ -152,6 +214,44 @@ fn zklora_halo2(_py: Python, m: &PyModule) -> PyResult<()> {
152214mod tests {
153215 use super :: * ;
154216
217+ #[ test]
218+ fn test_quantization_roundtrip ( ) {
219+ let test_values = vec ! [
220+ 0.0 , 1.0 , -1.0 ,
221+ 0.1234 , -0.1234 ,
222+ 123.456 , -123.456 ,
223+ 0.0001 , -0.0001 ,
224+ ] ;
225+
226+ for & value in test_values. iter ( ) {
227+ let field_val: Fp = quantize_to_field ( value) ;
228+ let roundtrip = dequantize_from_field ( field_val) ;
229+
230+ // Check that roundtrip preserves value within epsilon
231+ let epsilon = 1e-4 ; // Based on SCALE_FACTOR
232+ assert ! ( ( value - roundtrip) . abs( ) < epsilon,
233+ "Roundtrip failed for {}: got {}" , value, roundtrip) ;
234+ }
235+ }
236+
237+ #[ test]
238+ fn test_special_values ( ) {
239+ // Test zero
240+ let zero_field: Fp = quantize_to_field ( 0.0 ) ;
241+ assert_eq ! ( zero_field, Fp :: zero( ) ) ;
242+ assert_eq ! ( dequantize_from_field( zero_field) , 0.0 ) ;
243+
244+ // Test one
245+ let one_field: Fp = quantize_to_field ( 1.0 ) ;
246+ assert_eq ! ( one_field, Fp :: one( ) ) ;
247+ assert_eq ! ( dequantize_from_field( one_field) , 1.0 ) ;
248+
249+ // Test negative one
250+ let neg_one_field: Fp = quantize_to_field ( -1.0 ) ;
251+ assert_eq ! ( neg_one_field, -Fp :: one( ) ) ;
252+ assert_eq ! ( dequantize_from_field( neg_one_field) , -1.0 ) ;
253+ }
254+
155255 #[ test]
156256 fn test_proof_generation_and_verification ( ) {
157257 Python :: with_gil ( |py| {
0 commit comments