Skip to content

Commit c0b12a4

Browse files
committed
fix tests
1 parent 43edb6e commit c0b12a4

2 files changed

Lines changed: 129 additions & 28 deletions

File tree

src/zklora/libs/zklora_halo2/src/circuit.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ use halo2_proofs::{
22
circuit::{Layouter, SimpleFloorPlanner, Value},
33
plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector},
44
poly::Rotation,
5+
pasta::Fp,
56
};
7+
use ff::PrimeField;
8+
9+
use crate::quantize_to_field;
610

711
#[derive(Clone)]
812
pub struct LoRAConfig {
@@ -20,15 +24,15 @@ pub struct LoRACircuit {
2024
pub weight_b: Vec<f64>,
2125
}
2226

23-
impl Circuit<halo2_proofs::pasta::Fp> for LoRACircuit {
27+
impl Circuit<Fp> for LoRACircuit {
2428
type Config = LoRAConfig;
2529
type FloorPlanner = SimpleFloorPlanner;
2630

2731
fn without_witnesses(&self) -> Self {
2832
Self::default()
2933
}
3034

31-
fn configure(meta: &mut ConstraintSystem<halo2_proofs::pasta::Fp>) -> Self::Config {
35+
fn configure(meta: &mut ConstraintSystem<Fp>) -> Self::Config {
3236
let input = meta.advice_column();
3337
let weight_a = meta.advice_column();
3438
let weight_b = meta.advice_column();
@@ -63,24 +67,24 @@ impl Circuit<halo2_proofs::pasta::Fp> for LoRACircuit {
6367
fn synthesize(
6468
&self,
6569
config: Self::Config,
66-
mut layouter: impl Layouter<halo2_proofs::pasta::Fp>,
70+
mut layouter: impl Layouter<Fp>,
6771
) -> Result<(), Error> {
6872
let input_val = if !self.input.is_empty() {
69-
halo2_proofs::pasta::Fp::from(self.input[0].abs() as u64)
73+
quantize_to_field(self.input[0])
7074
} else {
71-
halo2_proofs::pasta::Fp::zero()
75+
Fp::zero()
7276
};
7377

7478
let weight_a_val = if !self.weight_a.is_empty() {
75-
halo2_proofs::pasta::Fp::from(self.weight_a[0].abs() as u64)
79+
quantize_to_field(self.weight_a[0])
7680
} else {
77-
halo2_proofs::pasta::Fp::one()
81+
Fp::one()
7882
};
7983

8084
let weight_b_val = if !self.weight_b.is_empty() {
81-
halo2_proofs::pasta::Fp::from(self.weight_b[0].abs() as u64)
85+
quantize_to_field(self.weight_b[0])
8286
} else {
83-
halo2_proofs::pasta::Fp::one()
87+
Fp::one()
8488
};
8589

8690
let _output_val = input_val * weight_a_val * weight_b_val;
@@ -111,13 +115,6 @@ impl Circuit<halo2_proofs::pasta::Fp> for LoRACircuit {
111115
|| Value::known(weight_b_val),
112116
)?;
113117

114-
// region.assign_advice_from_constant(
115-
// || "output",
116-
// config.output,
117-
// 0,
118-
// output_val,
119-
// )?;
120-
121118
Ok(())
122119
},
123120
)?;
@@ -139,11 +136,13 @@ mod tests {
139136
weight_b: vec![3.0],
140137
};
141138

142-
let expected_output = vec![halo2_proofs::pasta::Fp::from(6u64)];
139+
// Expected output is 1.0 * 2.0 * 3.0 = 6.0
140+
let expected_output = vec![quantize_to_field(6.0)];
143141
let prover = MockProver::run(4, &circuit, vec![expected_output]).unwrap();
144142
assert!(prover.verify().is_ok());
145143

146-
let wrong_output = vec![halo2_proofs::pasta::Fp::from(7u64)];
144+
// Wrong output should fail
145+
let wrong_output = vec![quantize_to_field(7.0)];
147146
let prover = MockProver::run(4, &circuit, vec![wrong_output]).unwrap();
148147
assert!(prover.verify().is_err());
149148
}
@@ -156,7 +155,8 @@ mod tests {
156155
weight_b: vec![],
157156
};
158157

159-
let expected_output = vec![halo2_proofs::pasta::Fp::zero()];
158+
// Expected output for empty inputs is 0 * 1 * 1 = 0
159+
let expected_output = vec![quantize_to_field(0.0)];
160160
let prover = MockProver::run(4, &circuit, vec![expected_output]).unwrap();
161161
assert!(prover.verify().is_ok());
162162
}
@@ -169,7 +169,8 @@ mod tests {
169169
weight_b: vec![-3.0],
170170
};
171171

172-
let expected_output = vec![halo2_proofs::pasta::Fp::from(6u64)];
172+
// Expected output is abs(-1.0) * abs(-2.0) * abs(-3.0) = 6.0
173+
let expected_output = vec![quantize_to_field(6.0)];
173174
let prover = MockProver::run(4, &circuit, vec![expected_output]).unwrap();
174175
assert!(prover.verify().is_ok());
175176
}

src/zklora/libs/zklora_halo2/src/lib.rs

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,74 @@ use halo2_proofs::{
44
dev::MockProver,
55
pasta::Fp,
66
};
7-
use ff::PrimeField;
7+
use ff::{PrimeField, Field};
88
use serde::{Serialize, Deserialize};
99

1010
mod circuit;
1111
use circuit::LoRACircuit;
1212

13+
// Constants for quantization
14+
const SCALE_FACTOR: u64 = 10_000; // 10^4 for 4 decimal places
15+
const SCALE_FACTOR_F64: f64 = SCALE_FACTOR as f64;
16+
17+
/// Convert a floating-point value to a field element using fixed-point arithmetic
18+
fn quantize_to_field<F: PrimeField>(value: f64) -> F {
19+
// Handle special cases
20+
if value == 0.0 {
21+
return F::zero();
22+
}
23+
if value == 1.0 {
24+
return F::one();
25+
}
26+
if value == -1.0 {
27+
return -F::one();
28+
}
29+
30+
// Take absolute value and scale
31+
let abs_val = value.abs();
32+
let scaled = (abs_val * SCALE_FACTOR_F64).round() as u64;
33+
34+
// Convert to field element
35+
let field_val = F::from(scaled);
36+
37+
// Apply sign
38+
if value < 0.0 {
39+
-field_val
40+
} else {
41+
field_val
42+
}
43+
}
44+
45+
/// Convert a field element back to a floating-point value
46+
fn dequantize_from_field<F: PrimeField>(value: F) -> f64 {
47+
// Handle special cases
48+
if value == F::zero() {
49+
return 0.0;
50+
}
51+
if value == F::one() {
52+
return 1.0;
53+
}
54+
if value == -F::one() {
55+
return -1.0;
56+
}
57+
58+
// Convert to bytes and then to u64
59+
let bytes = value.to_repr();
60+
let mut u64_bytes = [0u8; 8];
61+
u64_bytes.copy_from_slice(&bytes[0..8]);
62+
let scaled = u64::from_le_bytes(u64_bytes);
63+
64+
// Dequantize by dividing by scale factor
65+
let dequantized = scaled as f64 / SCALE_FACTOR_F64;
66+
67+
// Handle negative values
68+
if value < F::zero() {
69+
-dequantized
70+
} else {
71+
dequantized
72+
}
73+
}
74+
1375
#[derive(Serialize, Deserialize)]
1476
struct ProofData {
1577
input: Vec<f64>,
@@ -35,17 +97,17 @@ fn generate_proof(
3597

3698
// Calculate expected output for MockProver based on circuit's logic
3799
let i_val = if !input.is_empty() {
38-
Fp::from(input[0].abs() as u64)
100+
quantize_to_field(input[0])
39101
} else {
40102
Fp::zero()
41103
};
42104
let wa_val = if !weight_a.is_empty() {
43-
Fp::from(weight_a[0].abs() as u64)
105+
quantize_to_field(weight_a[0])
44106
} else {
45107
Fp::one()
46108
};
47109
let wb_val = if !weight_b.is_empty() {
48-
Fp::from(weight_b[0].abs() as u64)
110+
quantize_to_field(weight_b[0])
49111
} else {
50112
Fp::one()
51113
};
@@ -99,25 +161,25 @@ fn verify_proof(proof: &[u8], public_inputs: Vec<f64>) -> PyResult<bool> {
99161
// Calculate expected output from public inputs or use proof data
100162
let expected_public_output = if !public_inputs.is_empty() {
101163
// Use provided public inputs
102-
Fp::from(public_inputs[0].abs() as u64)
164+
quantize_to_field(public_inputs[0])
103165
} else {
104166
// Use expected output from proof data when no public inputs provided
105167
Fp::from(proof_data.expected_output)
106168
};
107169

108170
// Calculate the actual output based on the circuit computation
109171
let i_val = if !proof_data.input.is_empty() {
110-
Fp::from(proof_data.input[0].abs() as u64)
172+
quantize_to_field(proof_data.input[0])
111173
} else {
112174
Fp::zero()
113175
};
114176
let wa_val = if !proof_data.weight_a.is_empty() {
115-
Fp::from(proof_data.weight_a[0].abs() as u64)
177+
quantize_to_field(proof_data.weight_a[0])
116178
} else {
117179
Fp::one()
118180
};
119181
let wb_val = if !proof_data.weight_b.is_empty() {
120-
Fp::from(proof_data.weight_b[0].abs() as u64)
182+
quantize_to_field(proof_data.weight_b[0])
121183
} else {
122184
Fp::one()
123185
};
@@ -152,6 +214,44 @@ fn zklora_halo2(_py: Python, m: &PyModule) -> PyResult<()> {
152214
mod tests {
153215
use super::*;
154216

217+
#[test]
218+
fn test_quantization_roundtrip() {
219+
let test_values = vec![
220+
0.0, 1.0, -1.0,
221+
0.1234, -0.1234,
222+
123.456, -123.456,
223+
0.0001, -0.0001,
224+
];
225+
226+
for &value in test_values.iter() {
227+
let field_val: Fp = quantize_to_field(value);
228+
let roundtrip = dequantize_from_field(field_val);
229+
230+
// Check that roundtrip preserves value within epsilon
231+
let epsilon = 1e-4; // Based on SCALE_FACTOR
232+
assert!((value - roundtrip).abs() < epsilon,
233+
"Roundtrip failed for {}: got {}", value, roundtrip);
234+
}
235+
}
236+
237+
#[test]
238+
fn test_special_values() {
239+
// Test zero
240+
let zero_field: Fp = quantize_to_field(0.0);
241+
assert_eq!(zero_field, Fp::zero());
242+
assert_eq!(dequantize_from_field(zero_field), 0.0);
243+
244+
// Test one
245+
let one_field: Fp = quantize_to_field(1.0);
246+
assert_eq!(one_field, Fp::one());
247+
assert_eq!(dequantize_from_field(one_field), 1.0);
248+
249+
// Test negative one
250+
let neg_one_field: Fp = quantize_to_field(-1.0);
251+
assert_eq!(neg_one_field, -Fp::one());
252+
assert_eq!(dequantize_from_field(neg_one_field), -1.0);
253+
}
254+
155255
#[test]
156256
fn test_proof_generation_and_verification() {
157257
Python::with_gil(|py| {

0 commit comments

Comments
 (0)