diff --git a/.gitignore b/.gitignore index 47c5684..a08ba1e 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,9 @@ build/ *.whl .pypirc .coverage + +# Specific Cargo.lock files to ignore +zklora/libs/Cargo.lock +Cargo.lock + +.coverage \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..07842a7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,5 @@ +[workspace] +members = [ + "src/zklora/libs/plonky3" +] +resolver = "2" \ No newline at end of file diff --git a/readme.md b/README.md similarity index 100% rename from readme.md rename to README.md diff --git a/requirements.txt b/requirements.txt index 2e07a3b..1f8e393 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,8 @@ -numpy -pytest -torch -transformers \ No newline at end of file +torch>=2.0.0 +transformers>=4.30.0 +peft>=0.4.0 +onnx>=1.14.0 +onnxruntime>=1.15.0 +numpy>=1.24.0 +ezkl>=5.0.0 +maturin>=1.9.4 \ No newline at end of file diff --git a/src/requirements.txt b/src/requirements.txt deleted file mode 100644 index 19eb6f0..0000000 --- a/src/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -torch>=2.0.0 -transformers>=4.30.0 -peft>=0.4.0 -onnx>=1.14.0 -onnxruntime>=1.15.0 -numpy>=1.24.0 -ezkl>=5.0.0 diff --git a/src/zklora/__init__.py b/src/zklora/__init__.py index 0cadf1c..711ebe9 100644 --- a/src/zklora/__init__.py +++ b/src/zklora/__init__.py @@ -1,17 +1,25 @@ __version__ = "0.1.2" from .zk_proof_generator import batch_verify_proofs +from .zk_proof_generator import generate_proofs +from .mpi_lora_onnx_exporter import export_lora_onnx_json_mpi from .lora_contributor_mpi import LoRAServer, LoRAServerSocket from .base_model_user_mpi import BaseModelClient from .polynomial_commit import commit_activations, verify_commitment - +from .lora_onnx_exporter import export_lora_submodules +from .fp_coding import fixed_point_encode, fixed_point_decode __all__ = [ - "batch_verify_proofs", - "LoRAServer", - "LoRAServerSocket", - "BaseModelClient", - "commit_activations", - "verify_commitment", - "__version__", + 'batch_verify_proofs', + 'generate_proofs', + 'export_lora_onnx_json_mpi', + 'export_lora_submodules', + 'fixed_point_encode', + 'fixed_point_decode', + 'LoRAServer', + 'LoRAServerSocket', + 'BaseModelClient', + 'commit_activations', + 'verify_commitment', + '__version__', ] diff --git a/src/zklora/fp_coding.py b/src/zklora/fp_coding.py new file mode 100644 index 0000000..3f29d0f --- /dev/null +++ b/src/zklora/fp_coding.py @@ -0,0 +1,88 @@ +import numpy as np + +def fixed_point_encode(x, fractional_bits: int = 10, total_bits: int = 32): + """Encode floating-point values to *signed* fixed-point two's complement integers. + + Parameters + ---------- + x : array-like or scalar + Values convertible to ``np.float16``. + fractional_bits : int, default ``10`` + Number of bits reserved for the fractional part of the fixed-point number. + total_bits : int, default ``32`` + Total bit-width of the resulting fixed-point integer. ``total_bits`` must be + large enough to accommodate the integer, fractional **and** sign bit. 32 is + a sensible default because we return ``np.uint32`` values. + + Returns + ------- + list[int] + Two's-complement encoded unsigned integers (``np.uint32``) representing the + input values at the requested precision. + """ + + # Convert to higher precision float before scaling to avoid precision loss. + arr = np.asarray(x, dtype=np.float32) + multiplier = 2 ** fractional_bits + + # Scale and round towards nearest integer. + fixed = np.round(arr * multiplier).astype(np.int64) # temporary wider type + + # Check that the value fits in the desired bit-width (including sign bit). + max_val = (1 << (total_bits - 1)) - 1 + min_val = -1 << (total_bits - 1) + if np.any(fixed > max_val) or np.any(fixed < min_val): + raise OverflowError( + f"Value out of range for {total_bits}-bit fixed-point representation" + ) + + # Convert to two's complement by masking – this yields an *unsigned* view of + # exactly the same bits. + fixed_twos = (fixed & ((1 << total_bits) - 1)).astype(np.uint32) + + return fixed_twos.tolist() + +# Fixed point decoding function for 16-bit floating point numbers from u32 fixed-point representation +def fixed_point_decode(y, fractional_bits: int = 10, total_bits: int = 32): + """Decode signed fixed-point two's complement integers back to floats. + + Parameters + ---------- + y : array-like or scalar + Unsigned integers (``np.uint32`` or compatible) storing two's-complement + fixed-point values produced by :func:`fixed_point_encode`. + fractional_bits : int, default ``10`` + Number of fractional bits that were used during encoding. + total_bits : int, default ``32`` + Total width of the stored integers. Must match the value passed to + :func:`fixed_point_encode`. + + Returns + ------- + list[np.float16] + The decoded floating-point values. + """ + + arr_unsigned = np.asarray(y, dtype=np.uint32) + + # Re-interpret the unsigned integers as signed two's complement. + signed = arr_unsigned.view(np.int32) if total_bits == 32 else arr_unsigned.astype(np.int64) + + multiplier = 2 ** fractional_bits + decoded = signed.astype(np.float32) / multiplier + return decoded.astype(np.float16).tolist() + +if __name__ == "__main__": + # Basic sanity check for the fixed-point codec. + test_values = np.array([-3.25, -1.5, -0.125, 0.0, 0.125, 1.5, 3.25], dtype=np.float32) + + encoded = fixed_point_encode(test_values, fractional_bits=10) + decoded = np.array(fixed_point_decode(encoded, fractional_bits=10), dtype=np.float32) + + # Expect equality within one LSB of the fractional part. + assert np.allclose(test_values, decoded, atol=1 / (2 ** 10)), ( + f"Round-trip mismatch:\noriginal: {test_values}\ndecoded : {decoded}" + ) + + print("[fp_coding] Self-test passed for signed two's-complement encoding/decoding.") + \ No newline at end of file diff --git a/src/zklora/libs/Cargo.toml b/src/zklora/libs/Cargo.toml new file mode 100644 index 0000000..f670b7b --- /dev/null +++ b/src/zklora/libs/Cargo.toml @@ -0,0 +1,5 @@ +[workspace] +members = [ + "plonky3" +] +resolver = "2" \ No newline at end of file diff --git a/src/zklora/libs/plonky3/Cargo.toml b/src/zklora/libs/plonky3/Cargo.toml new file mode 100644 index 0000000..0a7a5ef --- /dev/null +++ b/src/zklora/libs/plonky3/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "plonky3" +version = "0.1.0" +edition = "2021" + +[lib] +name = "plonky3_py" +crate-type = ["cdylib"] + +[dependencies] +rand = "0.8.5" +pyo3 = { version = "0.24.1", features = ["extension-module"] } + +# plonky3 dependencies +p3-air = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-mersenne-31 = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-util = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-circle = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-commit = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-dft = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-goldilocks = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-blake3 = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-keccak = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-mds = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-poseidon = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-uni-stark = { git = "https://github.com/Plonky3/Plonky3.git" } +bincode = { version = "2.0.1", features = ["serde"] } \ No newline at end of file diff --git a/src/zklora/libs/plonky3/src/lib.rs b/src/zklora/libs/plonky3/src/lib.rs new file mode 100644 index 0000000..5f0bb0c --- /dev/null +++ b/src/zklora/libs/plonky3/src/lib.rs @@ -0,0 +1,309 @@ +use p3_field::{PrimeCharacteristicRing, PrimeField}; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use p3_mersenne_31::Mersenne31; +use p3_uni_stark::{prove, verify}; +use pyo3::prelude::*; +use pyo3::types::PyModule; +use pyo3::wrap_pyfunction; +use pyo3::Bound; + +pub mod vector_matrix_air; + +use vector_matrix_air::{VectorMatrixMultiplicationAIR, MyConfig}; + +/// Multiplies a vector by a matrix (vector * matrix). +/// +/// # Arguments +/// * `a` - A reference to a vector of field elements (length m) +/// * `b` - A reference to a matrix of field elements (m x n) +/// +/// # Returns +/// A vector of field elements (length n) representing the result of the multiplication. +/// +/// # Panics +/// Panics if the length of the vector does not match the height of the matrix. +pub fn vector_matrix_multiply(a: &Vec, b: &RowMajorMatrix) -> Vec { + assert_eq!( + a.len(), + b.height(), + "Vector length must match matrix height" + ); + let mut result = vec![F::ZERO; b.width()]; + for i in 0..b.width() { + for j in 0..b.height() { + result[i] += a[j] * b.get(j, i).unwrap(); + } + } + result +} + + +fn vector_matrix_transform( + m: usize, + n: usize, + v: &Vec, + a: &Vec>, +) -> (Vec, RowMajorMatrix) { + assert_eq!(v.len(), m, "Vector length must be m"); + assert_eq!(a.len(), m, "Matrix must have m rows"); + // Convert vector v from u32 to Mersenne31 + let vector: Vec = v.iter().map(|&x| Mersenne31::from_u32(x)).collect(); + + // Flatten the matrix a (Vec>) into a single Vec in row-major order + let mut matrix_flat: Vec = Vec::with_capacity(m * n); + for row in a { + assert_eq!(row.len(), n, "Each row of the matrix must have n columns"); + for &val in row { + matrix_flat.push(Mersenne31::from_u32(val)); + } + } + + let matrix = RowMajorMatrix::new(matrix_flat, n); + (vector, matrix) +} + +/// Generates a zero-knowledge proof for vector-matrix multiplication. +/// +/// This function creates a cryptographic proof that a given vector `v` was correctly +/// multiplied by a matrix `a` to produce a result vector, without revealing the actual +/// computation details. The proof can be verified by anyone without access to the +/// original inputs. +/// +/// # Arguments +/// +/// * `m` - The number of rows in the matrix (and the length of the input vector). +/// * `n` - The number of columns in the matrix (and the length of the output vector). +/// * `v` - A reference to the input vector of `u32` values to be multiplied. +/// * `a` - A reference to the matrix as a vector of vectors of `u32` values. +/// Must be an `m × n` matrix (m rows, n columns). +/// +/// # Returns +/// +/// A `Vec` containing the serialized zero-knowledge proof. This proof can be +/// verified using [`vector_matrix_multiplication_verify`] to confirm that the +/// multiplication was performed correctly without revealing the inputs. +/// +/// # Panics +/// +/// This function will panic if: +/// - The length of vector `v` is not equal to `m` +/// - The matrix `a` does not have exactly `m` rows +/// - Any row in matrix `a` does not have exactly `n` columns +/// +/// # Errors +/// +/// This function will return an error if: +/// - The proof generation fails +/// - The proof serialization fails +/// +/// # Implementation Details +/// +/// The function: +/// 1. Transforms the input vector and matrix into the appropriate field representation +/// 2. Creates a `VectorMatrixMultiplicationAIR` instance for the given dimensions +/// 3. Generates an execution trace for the computation +/// 4. Produces a zero-knowledge proof using the STARK proof system +/// 5. Serializes the proof using bincode for storage/transmission +/// +/// The proof is generated using the Plonky3 STARK system with Mersenne31 field elements. +#[pyfunction] +pub fn vector_matrix_multiplication_prove( + m: usize, + n: usize, + v: Vec, + a: Vec>, +) -> Vec { + let (vector, matrix) = vector_matrix_transform(m, n, &v, &a); + + let air = VectorMatrixMultiplicationAIR::new(m, n); + let trace = air.generate_trace(&vector, &matrix); + let proof = prove(&air.config, &air, trace, &vec![]); + + let config = bincode::config::standard() + .with_little_endian() + .with_fixed_int_encoding(); + bincode::serde::encode_to_vec(proof, config).expect("Failed to serialize proof") +} + +// Add a public helper that verifies a serialized proof. +/// Verify a previously generated vector-matrix multiplication proof. +/// +/// # Arguments +/// * `m` - Number of rows (length of the vector). +/// * `n` - Number of columns of the matrix. +/// * `proof` - A byte vector containing the serialized proof (as produced by +/// [`vector_matrix_multiplication_prove`]). +/// +/// # Returns +/// `true` if the proof is valid, `false` otherwise. +#[pyfunction] +pub fn vector_matrix_multiplication_verify(m: usize, n: usize, proof_bytes: Vec) -> bool { + // Deserialize proof bytes + let config_bin = bincode::config::standard() + .with_little_endian() + .with_fixed_int_encoding(); + + let (proof_deser, _): (p3_uni_stark::Proof, usize) = + match bincode::serde::decode_from_slice(&proof_bytes, config_bin) { + Ok(res) => res, + Err(_) => return false, // invalid encoding + }; + + let air = VectorMatrixMultiplicationAIR::new(m, n); + verify(&air.config, &air, &proof_deser, &vec![]).is_ok() +} + +/// The Python module definition. +/// `m` is the module object that will be returned to Python. +#[pymodule] +fn plonky3_py(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(vector_matrix_multiplication_prove, m)?)?; + m.add_function(wrap_pyfunction!(vector_matrix_multiplication_verify, m)?)?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use p3_field::integers::QuotientMap; + use p3_matrix::dense::RowMajorMatrix; + use p3_mersenne_31::Mersenne31; + use p3_uni_stark::{prove, verify, Proof, StarkGenericConfig}; + + /// Prints the execution trace of a vector-matrix multiplication as a table. + /// + /// # Arguments + /// + /// * `trace` - A reference to a `RowMajorMatrix` representing the execution trace. + /// + /// Each row of the trace corresponds to a step in the computation, and each column contains + /// a value relevant to the AIR (Algebraic Intermediate Representation) for the vector-matrix multiplication. + /// This function prints each row of the trace, with values separated by commas, for debugging and inspection. + fn print_trace(trace: &RowMajorMatrix) { + println!("Trace (one row per line):"); + for i in 0..trace.height() { + let mut row_values = Vec::new(); + for j in 0..trace.width() { + row_values.push(format!("{}", trace.get(i, j).unwrap())); + } + println!("Row {}: [{}]", i, row_values.join(", ")); + } + } + + /// Report the size of the serialized proof. + /// + /// Serializes the given proof instance using bincode and prints the size in bytes. + /// Panics if serialization fails. + fn report_proof_size(proof: &Proof) + where + SC: StarkGenericConfig, + { + let config = bincode::config::standard() + .with_little_endian() + .with_fixed_int_encoding(); + let proof_bytes = + bincode::serde::encode_to_vec(proof, config).expect("Failed to serialize proof"); + println!("Proof size: {} bytes", proof_bytes.len()); + } + + #[test] + fn test_vector_matrix_multiplication_prove() { + let vector = vec![1, 2, 3]; + let matrix = vec![vec![1, 2], vec![4, 5], vec![7, 8]]; + let proof = vector_matrix_multiplication_prove(3, 2, vector, matrix); + + let result = vector_matrix_multiplication_verify(3, 2, proof); + assert!(result); + } + + #[test] + fn test_proving() { + let vector = vec![ + Mersenne31::from_int(1), + Mersenne31::from_int(2), + Mersenne31::from_int(3), + ]; + // [[1, 2], [3, 4]] + let matrix = RowMajorMatrix::new( + vec![ + Mersenne31::from_int(1), + Mersenne31::from_int(2), + Mersenne31::from_int(3), + Mersenne31::from_int(4), + Mersenne31::from_int(5), + Mersenne31::from_int(6), + Mersenne31::from_int(7), + Mersenne31::from_int(8), + Mersenne31::from_int(9), + ], + 3, + ); + + println!("matrix: {:?}", matrix); + + let air = VectorMatrixMultiplicationAIR::new(3, 3); + let trace = air.generate_trace(&vector, &matrix); + println!("trace width: {:?}", trace.width()); + print_trace(&trace); + let proof = prove(&air.config, &air, trace, &vec![]); + report_proof_size(&proof); + let result = verify(&air.config, &air, &proof, &vec![]); + assert!(result.is_ok()); + } + + #[test] + fn test_trace() { + let vector = vec![Mersenne31::from_int(1), Mersenne31::from_int(2)]; + // [[1, 2], [3, 4]] + let matrix = RowMajorMatrix::new( + vec![ + Mersenne31::from_int(1), + Mersenne31::from_int(2), + Mersenne31::from_int(3), + Mersenne31::from_int(4), + ], + 2, + ); + + let air = VectorMatrixMultiplicationAIR::new(2, 2); + let trace = air.generate_trace(&vector, &matrix); + // Print the trace, one row per line + print_trace(&trace); + + // Row 0: [1, 2, 1, 3, 2, 4, 1, 0, 1, 0, 0, 0, 1] + // Row 1: [1, 2, 1, 3, 2, 4, 0, 1, 0, 1, 0, 0, 7] + // Row 2: [1, 2, 1, 3, 2, 4, 1, 0, 0, 0, 1, 0, 2] + // Row 3: [1, 2, 1, 3, 2, 4, 0, 1, 0, 0, 0, 1, 10] + + #[rustfmt::skip] + let correct_trace: RowMajorMatrix = RowMajorMatrix::new( + vec![ + Mersenne31::from_int(1), Mersenne31::from_int(2), Mersenne31::from_int(1), Mersenne31::from_int(3), Mersenne31::from_int(2), Mersenne31::from_int(4), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(1), + Mersenne31::from_int(1), Mersenne31::from_int(2), Mersenne31::from_int(1), Mersenne31::from_int(3), Mersenne31::from_int(2), Mersenne31::from_int(4), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(7), Mersenne31::from_int(1), + Mersenne31::from_int(1), Mersenne31::from_int(2), Mersenne31::from_int(1), Mersenne31::from_int(3), Mersenne31::from_int(2), Mersenne31::from_int(4), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(2), Mersenne31::from_int(1), + Mersenne31::from_int(1), Mersenne31::from_int(2), Mersenne31::from_int(1), Mersenne31::from_int(3), Mersenne31::from_int(2), Mersenne31::from_int(4), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(0), Mersenne31::from_int(1), Mersenne31::from_int(10), Mersenne31::from_int(1), + ], + 14, + ); + assert_eq!(trace.width, correct_trace.width); + assert_eq!(trace, correct_trace); + } + + #[test] + fn test_vector_matrix_multiplication() { + let vector = vec![Mersenne31::from_int(1), Mersenne31::from_int(2)]; + let matrix = RowMajorMatrix::new( + vec![ + Mersenne31::from_int(1), + Mersenne31::from_int(2), + Mersenne31::from_int(3), + Mersenne31::from_int(4), + ], + 2, + ); + + let real_result = vec![Mersenne31::from_int(7), Mersenne31::from_int(10)]; + let result = vector_matrix_multiply(&vector, &matrix); + assert_eq!(result, real_result); + } +} diff --git a/src/zklora/libs/plonky3/src/vector_matrix_air.rs b/src/zklora/libs/plonky3/src/vector_matrix_air.rs new file mode 100644 index 0000000..872a2aa --- /dev/null +++ b/src/zklora/libs/plonky3/src/vector_matrix_air.rs @@ -0,0 +1,378 @@ +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_challenger::HashChallenger; +use p3_challenger::SerializingChallenger32; +use p3_circle::CirclePcs; +use p3_commit::ExtensionMmcs; +use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing, PrimeField}; +use p3_fri::FriParameters; +use p3_keccak::Keccak256Hash; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use p3_merkle_tree::MerkleTreeMmcs; +use p3_mersenne_31::Mersenne31; +use p3_symmetric::CompressionFunctionFromHasher; +use p3_symmetric::SerializingHasher; +use p3_uni_stark::StarkConfig; + + +/// AIR (Algebraic Intermediate Representation) for vector-matrix multiplication. +/// This struct represents the configuration and constraints for proving vector-matrix multiplication +/// using an algebraic execution trace. It tracks the dimensions of the input matrices +/// and provides methods for generating and verifying the computation trace. +pub struct VectorMatrixMultiplicationAIR { + /// Length of vector (number of rows in matrix) + pub m: usize, + /// Number of columns in matrix + pub n: usize, + + pub byte_hash: ByteHash, + pub field_hash: FieldHash, + pub compress: MyCompress, + pub val_mmcs: ValMmcs, + pub challenge_mmcs: ChallengeMmcs, + pub config: MyConfig, + + /// Field element type + _phantom: std::marker::PhantomData, +} + +// Field +type Val = Mersenne31; + +// This creates a cubic extension field over Val using a binomial basis. It's used for generating challenges in the proof system. +// The reason why we want to extend our field for Challenges, is because the original Field size is too small to be brute-forced to solve the challenge. +type Challenge = BinomialExtensionField; +// Your choice of Hash Function +type ByteHash = Keccak256Hash; +// A serializer for Hash function, so that it can take Fields as inputs +type FieldHash = SerializingHasher; +// Defines a compression function type using ByteHash, with 2 input blocks and 32-byte output. +type MyCompress = CompressionFunctionFromHasher; +// Defines a Merkle tree commitment scheme for field elements with 32 levels. +type ValMmcs = MerkleTreeMmcs; +// Defines an extension of the Merkle tree commitment scheme for the challenge field. +type ChallengeMmcs = ExtensionMmcs; +// Defines the challenger type for generating random challenges. +type Challenger = SerializingChallenger32>; +// Defines the polynomial commitment scheme type. +type Pcs = CirclePcs; +// Defines the overall STARK configuration type. +pub type MyConfig = StarkConfig; + +impl VectorMatrixMultiplicationAIR { + pub fn new(m: usize, n: usize) -> Self { + // Declaring an empty hash and its serializer. + let byte_hash = ByteHash {}; + // Declaring Field hash function, it is used to hash field elements in the proof system + let field_hash = FieldHash::new(byte_hash); + // Creates a new instance of the compression function. + let compress = MyCompress::new(byte_hash); + // Instantiates the Merkle tree commitment scheme. + let val_mmcs = ValMmcs::new(field_hash, compress.clone()); + // Creates an instance of the challenge Merkle tree commitment scheme. + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + // Configures the FRI (Fast Reed-Solomon IOP) protocol parameters. + let fri_config = FriParameters { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs.clone(), + log_final_poly_len: 1, + }; + // Instantiates the polynomial commitment scheme with the above parameters. + let pcs = Pcs { + mmcs: val_mmcs.clone(), + fri_params: fri_config, + _phantom: std::marker::PhantomData, + }; + let challenger = Challenger::from_hasher(vec![], byte_hash); + // Creates the STARK configuration instance. + let config = MyConfig::new(pcs, challenger); + + Self { + m, + n, + byte_hash, + field_hash, + compress, + val_mmcs, + challenge_mmcs, + config, + _phantom: std::marker::PhantomData, + } + } + + /// Returns the number of columns in the execution trace + /// + /// The execution trace for matrix multiplication requires columns to store: + /// - All elements from vector (m columns) + /// - All elements from matrix (n * m columns) + /// - m + (m * n) columns for the selectors + /// - One column for the running sum of the current element being computed + /// - One column for the row in the trace that is enabled + /// + /// The total width is calculated as: 2 * (m + m * n) + 1 + pub fn trace_width(&self) -> usize { + 2 * (self.m + self.m * self.n) + 1 + 1 + } + + /// Pushes the matrix elements to the trace data with column major order + fn push_matrix(&self, trace_data: &mut Vec, a: &RowMajorMatrix) { + for i in 0..self.n { + for j in 0..self.m { + trace_data.push(a.get(j, i).unwrap()); + } + } + } + + /// Generates a computation trace for vector-matrix multiplication + /// + /// # Arguments + /// * `v` - Input vector (length m) with field elements + /// * `a` - Matrix (m x n) with field elements + /// + /// # Returns + /// A matrix representing the execution trace, where each row is a step in the computation + /// and columns correspond to: + /// - Columns 0..m: All elements from vector v + /// - Columns m..(m+m*n): All elements from matrix a (in column-major order) + /// - Columns (m+m*n)..(m+m*n+m): Vector selector (one-hot encoding for vector elements) + /// - Columns (m+m*n+m)..(m+m*n+m+m*n): Matrix selector (one-hot encoding for matrix elements) + /// - Column trace_width()-2: Running sum for the current dot product computation + /// - Column trace_width()-1: Enabled flag (1 if row is active, 0 if padding) + pub fn generate_trace(&self, v: &Vec, a: &RowMajorMatrix) -> RowMajorMatrix { + assert_eq!( + a.height(), + self.m, + "Matrix height should match AIR configuration" + ); + assert_eq!( + a.width(), + self.n, + "Matrix width should match AIR configuration" + ); + assert_eq!( + v.len(), + self.m, + "Vector length should match AIR configuration" + ); + + // Compute total number of steps needed for the trace + // For each element V[i], we need m steps to compute the dot product + let total_rows = self.m * self.n; + + // Initialize the trace matrix with F elements + let mut trace_data: Vec = Vec::with_capacity(total_rows * self.trace_width()); + + let mut vector_selector: Vec = vec![F::ONE] + .into_iter() + .chain(std::iter::repeat(F::ZERO).take(self.m - 1)) + .collect(); + let mut matrix_selector: Vec = vec![F::ONE] + .into_iter() + .chain(std::iter::repeat(F::ZERO).take(self.m * self.n - 1)) + .collect(); + + let mut previous_sum = F::ZERO; + + // Generate the step-by-step trace + for _ in 0..total_rows { + trace_data.extend_from_slice(v); + self.push_matrix(&mut trace_data, a); + trace_data.extend_from_slice(&vector_selector); + trace_data.extend_from_slice(&matrix_selector); + + // Find the index in vector_selector where the value is F::ONE + let vector_index = vector_selector + .iter() + .position(|&x| x == F::ONE) + .expect("vector_selector should contain F::ONE"); + + // Get the index in matrix_selector where the value is F::ONE + let matrix_index = matrix_selector + .iter() + .position(|&x| x == F::ONE) + .expect("matrix_selector should contain F::ONE"); + + let running_sum = if vector_index > 0 { + trace_data[vector_index] * trace_data[self.m + matrix_index] + previous_sum + } else { + trace_data[vector_index] * trace_data[self.m + matrix_index] + }; + trace_data.push(running_sum); + previous_sum = running_sum.clone(); + + trace_data.push(F::ONE); + + vector_selector.rotate_right(1); + matrix_selector.rotate_right(1); + } + + // If the trace length is not a power of two, pad it with dummy rows of zeros so that + // the total number of rows is the next power of two. The last column (running sum) + // is explicitly set to 0 for these padding rows to indicate that they are inactive. + + let width = self.trace_width(); + let padded_rows = total_rows.next_power_of_two(); + if padded_rows > total_rows { + for _ in 0..(padded_rows - total_rows) { + // Push `width` zeros. Since we are inside a PrimeField, F::ZERO is valid. + // The last column (running sum) remains 0 as well. + trace_data.extend(std::iter::repeat(F::ZERO).take(width)); + } + } + + RowMajorMatrix::new(trace_data, width) + } +} + +impl BaseAir for VectorMatrixMultiplicationAIR { + fn width(&self) -> usize { + self.trace_width() + } +} + +impl Air for VectorMatrixMultiplicationAIR +where + AB::F: PrimeField, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let current = main.row_slice(0).unwrap(); + let next = main.row_slice(1).unwrap(); + + let v_sel_init = self.n * self.m + self.m; + let m_sel_init = self.n * self.m + self.m + self.m; + let matrix_init = self.m; + let sum = self.trace_width() - 2; + let enabled = self.trace_width() - 1; + + // Enforce starting state + // the row is enabled + builder + .when_first_row() + .assert_one(current[enabled].clone()); + + // sum equal the first element of the vector times the first element of the matrix + builder.when_first_row().assert_eq( + current[sum].clone(), + current[0].clone() * current[matrix_init].clone(), + ); + + // The first element of the vector selector is 1 + builder + .when_first_row() + .assert_one(current[v_sel_init].clone()); + // The rest of the vector selectors are 0 + for i in 1..self.m { + builder + .when_first_row() + .assert_zero(current[v_sel_init + i].clone()); + } + + // The first element of the matrix selector is 1 + builder + .when_first_row() + .assert_one(current[m_sel_init].clone()); + // The rest of the matrix selectors are 0 + for i in 1..self.m { + builder + .when_first_row() + .assert_zero(current[m_sel_init + i].clone()); + } + + // Enforce final enabled row is followed by disabled row + builder + .when_transition() + .when(current[enabled].clone()) + .when(current[sum - 1].clone()) + .assert_zero(next[enabled].clone()); + + // Enforce rows are all 0 in the last column after the last enabled row + builder + .when_transition() + .when(AB::Expr::ONE - current[enabled].clone()) + .assert_zero(next[enabled].clone()); + + // Enforce booleanity of the vector selector + for i in 0..self.m { + builder + .when_transition() + .when(current[enabled].clone()) + .assert_bool(current[v_sel_init + i].clone()); + } + + // Enforce booleanity of the matrix selector + for i in 0..self.m { + builder + .when_transition() + .when(current[enabled].clone()) + .assert_bool(current[m_sel_init + i].clone()); + } + + // Enforce booleanity of the enabled colum + builder + .when_transition() + .assert_bool(current[enabled].clone()); + builder.when_last_row().assert_bool(next[enabled].clone()); + + // Enforce the sum of the vector selector is 1 + let mut acum = AB::Expr::ZERO; + for i in 0..self.m { + acum += current[v_sel_init + i].clone(); + } + builder + .when_transition() + .when(current[enabled].clone()) + .assert_eq(acum, AB::Expr::ONE); + + // Enforce the sum of the matrix selector is 1 + let mut acum = AB::Expr::ZERO; + for i in 0..self.m * self.n { + acum += current[m_sel_init + i].clone(); + } + builder + .when_transition() + .when(current[enabled].clone()) + .assert_eq(acum, AB::Expr::ONE); + + // Enforce the vector and matrix do not change between rows + for i in 0..self.m + self.m * self.n { + builder + .when_transition() + .when(current[enabled].clone()) + .when(next[enabled].clone()) + .assert_eq(current[i].clone(), next[i].clone()); + } + + // Enforce the correct vector-matrix multiplication result + // If the first element of the vector selector is 1, then + // the sum colum does not accumulate from the previous row + for i in 0..self.m * self.n { + builder + .when_transition() + .when(current[enabled].clone()) + .when(current[v_sel_init].clone()) + .when(current[m_sel_init + i].clone()) + .assert_eq( + current[sum].clone(), + current[0].clone() * current[matrix_init + i].clone(), + ); + } + // If the first element of the vector selector is 0, then + // the sum colum accumulates from the previous row + for i in 1..self.m { + for j in 0..self.m * self.n { + builder + .when_transition() + .when(next[enabled].clone()) + .when(AB::Expr::ONE - next[v_sel_init].clone()) + .when(next[v_sel_init + i].clone()) + .when(next[m_sel_init + j].clone()) + .assert_eq( + next[sum].clone(), + current[sum].clone() + next[i].clone() * next[matrix_init + j].clone(), + ); + } + } + } +} diff --git a/src/zklora/lora_onnx_exporter.py b/src/zklora/lora_onnx_exporter.py new file mode 100644 index 0000000..974c368 --- /dev/null +++ b/src/zklora/lora_onnx_exporter.py @@ -0,0 +1,270 @@ +import json +import os +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from peft import PeftModel +from transformers import PreTrainedTokenizer + + +# A helper to fix shapes for A, B +def normalize_lora_matrices( + A: torch.Tensor, B: torch.Tensor, x_data: np.ndarray +) -> tuple[torch.Tensor, torch.Tensor, int, int, int]: + """ + x_data shape => (batch, seq_len, hidden_dim). + We ensure A => [hidden_dim, r], B => [r, out_dim]. + """ + in_dim = x_data.shape[-1] + a0, a1 = A.shape + # A => [in_dim, r] + if a0 == in_dim: + r = a1 + elif a1 == in_dim: + A = A.transpose(0, 1) + r = A.shape[1] + else: + raise ValueError(f"A shape {A.shape} doesn't match x_data last dim {in_dim}.") + + b0, b1 = B.shape + if b0 == r: + out_dim = b1 + elif b1 == r: + B = B.transpose(0, 1) + out_dim = B.shape[1] + else: + raise ValueError(f"B shape {B.shape} doesn't match rank={r} in any dimension.") + return A, B, in_dim, r, out_dim + + +class LoraShapeTransformer(nn.Module): + """ + Expects shape (1, batch*seq_len*hidden_dim). + Internal forward => reshape to (batch, seq_len, hidden_dim). + """ + + def __init__(self, A, B, batch_size, seq_len, hidden_dim): + super().__init__() + self.register_buffer("A", A) + self.register_buffer("B", B) + self.batch_size = batch_size + self.seq_len = seq_len + self.hidden_dim = hidden_dim + + def forward(self, x_1d: torch.Tensor) -> torch.Tensor: + # x_1d => shape (1, total_size) + x_3d = x_1d.view(self.batch_size, self.seq_len, self.hidden_dim) + out_3d = (x_3d @ self.A) @ self.B + out_3d = out_3d + x_3d.mean() + self.A.sum() + self.B.sum() + # Flatten output for demonstration + out_2d = out_3d.view(1, -1) + return out_2d + + +def make_activation_hook(mod_name: str, activation_map: dict) -> callable: + """Creates a hook function for capturing LoRA submodule activations.""" + + def hook(mod, layer_inputs, layer_output) -> None: + if not layer_inputs: + return + x = layer_inputs[0] # shape: (batch, seq_len, hidden_dim) + activation_map[mod_name] = x.detach().cpu().float().numpy() + + return hook + + +def register_lora_hooks( + model: PeftModel, activation_map: dict, submodule_key: str = None +) -> None: + """ + Recursively finds LoRA submodules and registers forward hooks. + Args: + model: The PEFT model to hook + activation_map: Dictionary to store activations + submodule_key: If set, only hook submodules containing this key + """ + issued_wte_warning = False + + for full_name, module in model.named_modules(): + # Check if this submodule has LoRA + if hasattr(module, "lora_A") and hasattr(module, "lora_B"): + # Skip embedding submodules + if "wte" in full_name or "wpe" in full_name: + if not issued_wte_warning: + print( + "WARNING: Found LoRA submodule '{full_name}' (wte/wpe). " + "Skipping hooking embeddings." + ) + issued_wte_warning = True + continue + + # If user wants to filter e.g. "attn.c_attn" + if submodule_key and submodule_key not in full_name: + continue + + module.register_forward_hook( + make_activation_hook(full_name, activation_map) + ) + + +def export_lora_submodules( + model: PeftModel, + tokenizer: PreTrainedTokenizer, + input_texts: list[str], + output_dir: str = "lora_onnx_params", + json_dir: str = "intermediate_activations", + submodule_key: str = None, + verbose: bool = False, +) -> None: + """ + 1) Captures LoRA sub-layer inputs with shape (batch, seq_len, hidden_dim). + 2) Flattens them into shape (1, batch*seq_len*hidden_dim). + 3) Exports an ONNX submodule that expects (1, total_size). + - Inside that submodule's forward pass, it reshapes back to (batch, seq_len, hidden_dim). + 4) Writes a JSON file containing a single row of floats ( shape => (1, total_size) ). + + This function alone does not generate proofs; it only creates the ONNX/JSON pairs. + You can run your separate proof-generation code (like `generate_proofs_async`) on them. + + Args: + model: A LoRA-augmented (PEFT) model, in eval mode. + tokenizer: A tokenizer (from the same or compatible base model). + input_texts: A list of strings for batched input. e.g. ["Hello", "More text", ...] + output_dir: Where to save ONNX files. + json_dir: Where to save JSON files. + submodule_key: If set (e.g. "attn.c_attn"), export only submodules containing that key. + verbose: If True, print debug info. + """ + + # Ensure directories exist + os.makedirs(output_dir, exist_ok=True) + os.makedirs(json_dir, exist_ok=True) + + # Ensure we can pad if GPT-2-like + if tokenizer.pad_token is None: + # Option 1: use eos token as pad + tokenizer.pad_token = tokenizer.eos_token + # Option 2 (alternative): + # tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + # model.resize_token_embeddings(len(tokenizer)) + + # We'll store each sub-layer input in a dict + activation_map = {} + + # Register hooks before forward pass + register_lora_hooks(model, activation_map, submodule_key) + + # Tokenize the input text as a single batch + inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True) + input_ids = inputs["input_ids"] + # e.g. shape: (batch, seq_len) + if verbose: + print("input_ids shape:", input_ids.shape) + + # Run forward pass + with torch.no_grad(): + _ = model(input_ids) + + # If no sub-layer activations were captured + if len(activation_map) == 0: + print( + "No LoRA sub-layer activations captured. Possibly no triggers for these inputs." + ) + return + + # For each submodule hooking + for full_name, x_data in activation_map.items(): + # x_data shape => (batch, seq_len, hidden_dim) + batch_size = x_data.shape[0] + seq_len = x_data.shape[1] + hidden_dim = x_data.shape[2] + + total_size = batch_size * seq_len * hidden_dim # e.g. 3*4*768=9216 + # Flatten to => (1, total_size) + one_row = x_data.reshape(1, total_size) + + # Look up the submodule + submodule = dict(model.named_modules()).get(full_name, None) + if submodule is None: + print(f"Cannot find submodule {full_name} in the model dict, skipping.") + continue + + # Extract A,B + if hasattr(submodule.lora_A, "keys"): + a_keys = list(submodule.lora_A.keys()) + if not a_keys: + print(f"No keys in submodule.lora_A for {full_name}, skipping.") + continue + A_mod = submodule.lora_A[a_keys[0]] + else: + A_mod = submodule.lora_A + + if hasattr(submodule.lora_B, "keys"): + b_keys = list(submodule.lora_B.keys()) + if not b_keys: + print(f"No keys in submodule.lora_B for {full_name}, skipping.") + continue + B_mod = submodule.lora_B[b_keys[0]] + else: + B_mod = submodule.lora_B + + if not hasattr(A_mod, "weight"): + print(f"No weight in lora_A for {full_name}, skipping.") + continue + if not hasattr(B_mod, "weight"): + print(f"No weight in lora_B for {full_name}, skipping.") + continue + + A_raw = A_mod.weight.detach().cpu().float() + B_raw = B_mod.weight.detach().cpu().float() + + # fix shapes + try: + A_fixed, B_fixed, in_dim, rank, out_dim = normalize_lora_matrices( + A_raw, B_raw, x_data + ) + except ValueError as e: + print(f"Shape fix error for {full_name}: {e}") + continue + + # Build sub-module expecting => (1, total_size) + lora_mod = LoraShapeTransformer( + A_fixed, B_fixed, batch_size, seq_len, hidden_dim + ).eval() + + # Save ONNX + safe_name = full_name.replace(".", "_").replace("/", "_") + onnx_path = os.path.join(output_dir, f"{safe_name}.onnx") + + x_tensor = torch.from_numpy(one_row) + import onnx + from torch.onnx import TrainingMode + + try: + torch.onnx.export( + lora_mod, + x_tensor, + onnx_path, + export_params=True, + do_constant_folding=False, + opset_version=12, + input_names=["input_x"], + output_names=["output"], + training=TrainingMode.TRAINING, + keep_initializers_as_inputs=False, + ) + except Exception as e: + print(f"Export error for {full_name}: {e}") + continue + + # Save JSON => single row of shape (1, total_size) + data_json = {"input_data": one_row.tolist()} + json_path = os.path.join(json_dir, f"{safe_name}.json") + with open(json_path, "w") as f: + json.dump(data_json, f) + + if verbose: + print(f"Exported ONNX for {full_name} -> {onnx_path}") + print(f"Saved JSON -> {json_path}, shape => {one_row.shape}") diff --git a/src/zklora/plonky3_prove_verify.py b/src/zklora/plonky3_prove_verify.py new file mode 100644 index 0000000..6d6b70b --- /dev/null +++ b/src/zklora/plonky3_prove_verify.py @@ -0,0 +1,60 @@ +import plonky3_py as pl + +from zklora import export_lora_submodules, generate_proofs, batch_verify_proofs + +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import PeftModel +import asyncio +import ezkl + +# Fixed point encoding function for 16-bit floating point numbers +import numpy as np +from zklora.fp_coding import fixed_point_encode, fixed_point_decode + + +# Patch ezkl.gen_witness to be awaitable so generate_proofs can await it +if not asyncio.iscoroutinefunction(getattr(ezkl, "gen_witness", None)): + _orig_gen_witness = ezkl.gen_witness + + async def _gen_witness_async(*args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: _orig_gen_witness(*args, **kwargs)) + + ezkl.gen_witness = _gen_witness_async # type: ignore + +v = [0.1, 0.2, 0.3] +A = [[0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9]] + +v_encoded = fixed_point_encode(v) +A_encoded = [fixed_point_encode(row) for row in A] + +proof = pl.vector_matrix_multiplication_prove(3, 3, v_encoded, A_encoded) +assert pl.vector_matrix_multiplication_verify(3, 3, proof) +print("proof length:", len(proof)) + + + +async def main(): + base_model_name = "distilgpt2" + lora_model_name = "q1e123/peft-starcoder-lora-a100" + base_model = AutoModelForCausalLM.from_pretrained(base_model_name) + lora_model = PeftModel.from_pretrained(base_model, lora_model_name) + tokenizer = AutoTokenizer.from_pretrained(base_model_name) + lora_model.eval() + + texts = ["Hello from LoRA", "And another test", "One more line..."] + + export_lora_submodules( + model=lora_model, + tokenizer=tokenizer, + input_texts=texts, + submodule_key="attn.c_attn", + ) + + await generate_proofs(verbose=True, zk_backend="plonky3") + +if __name__ == "__main__": + asyncio.run(main()) + print("Done") diff --git a/src/zklora/strategy.md b/src/zklora/strategy.md new file mode 100644 index 0000000..6bad45f --- /dev/null +++ b/src/zklora/strategy.md @@ -0,0 +1,420 @@ +Root cause + +The Rust AIR behind `pl.vector_matrix_multiplication_prove` builds an execution-trace matrix whose size scales as + + rows = m · n + cols = 2 · (m + m·n) + 2 ≈ 2 m (n+1) + +Total cells ≈ m · n · 2 m (n+1)  →  Θ(m² n²) + +For your real model layer + m = 768, n = 2 304 → + rows = 1 769 472 + cols ≈ 3 543 682 + +Cells ≈ 6.3 × 10¹² → ~25 TB at 4 bytes/field-element ⇒ the allocator request you saw. + +Why it is cubic + +The current AIR keeps the entire matrix (m × n) and two one-hot selector columns in every trace row. That duplicates the input m · n times, hence the quadratic blow-up in both m and n. + +Fix: redesign the proof interface to avoid storing the full matrix in every row + +1. Prove a single dot-product at a time + • Treat each output coordinate yₖ = v·Wₖ (Wₖ = k-th column) as its own STARK. + • Now n = 1 in the AIR ⇒ + rows = m, cols ≈ 2 m+2 → O(m²) elements ≈ 1.2 M (3–5 MB). + • Repeat/aggregate the proofs for all 2 304 columns (or only the ones you need). + +2. Or keep vector-matrix multiplication but store the vector once and stream the matrix column-major through the trace (add running product accumulators instead of one-hot selectors). That brings the width down to O(m+n). + +Quick path you can apply today + +Modify the Python driver in `zk_proof_generator.py` to call the prover column-wise: + +```python +for k in range(n): # n = 2304 + col_k = [row[k] for row in W] # length m + col_k_encoded = fixed_point_encode(col_k, fractional_bits=24) + proof_k = pl.vector_dot_prove(m, v_encoded, col_k_encoded) + # or use the existing function once you expose a dot-product variant +``` + +Corresponding Rust addition (sketch): + +```rust +#[pyfunction] +fn vector_dot_prove(py: Python<'_>, m: usize, v: Vec, w: Vec) -> PyResult> { + // AIR with n = 1; reuse existing code paths but skip matrix selector, etc. +} +``` + +Because each proof is only a few MB, you can batch-verify them afterwards with Plonky3’s multi-proof verifier. + +Long-term + +• Refactor `VectorMatrixMultiplicationAIR::trace_width` to avoid the `m*n` term. +• Replace explicit one-hot selectors with arithmetic constraints (`sel_i` computed on the fly). +• Consider folding rows so the trace has O(m+n) size. + +Until that redesign lands, splitting into per-column (or per-row) proofs is the simplest unblocker and will run comfortably within RAM. + + +# Proof Compression + +Yes, you can create a single compressed proof from multiple proofs! This is called **proof aggregation** or **recursive proving**. There are a few approaches: + +## 1. Recursive Proof Aggregation (Most Common) + +Instead of just batch-verifying multiple proofs, you can create a "proof of proofs" – a single proof that attests to the validity of all the original proofs. This typically works by: + +1. **Creating individual proofs** for each computation +2. **Creating a verification circuit** that checks all the original proofs +3. **Proving the verification circuit** to get a single aggregated proof + +## 2. Current Plonky3 Limitations & Alternatives + +**Plonky3 doesn't have built-in recursion** like Plonky2 does. Here are your options: + +### Option A: Use Plonky2 for Recursion +```rust +// Generate Plonky3 proofs first +let plonky3_proofs = generate_all_plonky3_proofs(); + +// Use Plonky2 to recursively aggregate them +// Plonky2 can verify Plonky3 proofs inside a circuit +let aggregated_proof = plonky2_recursive_aggregation(plonky3_proofs); +``` + +### Option B: Use a SNARK Wrapper +Many projects use a two-stage approach: +1. Generate STARK proofs with Plonky3 (fast proving, larger proofs) +2. Wrap them in a SNARK (slower proving, tiny proof) + +```rust +// Stage 1: Generate Plonky3 STARKs +let stark_proofs = vec![proof1, proof2, proof3]; + +// Stage 2: Aggregate into a single SNARK +// Using Groth16, Plonk, or another SNARK system +let snark_proof = snark_aggregate(stark_proofs); // ~200-500 bytes +``` + +### Option C: Custom Aggregation Circuit in Plonky3 +You could write an AIR that verifies other proofs: + +```rust +// Pseudocode for a proof aggregation AIR +struct ProofAggregationAIR { + num_proofs: usize, + // Store commitments and public inputs +} + +impl Air for ProofAggregationAIR { + fn eval(&self, builder: &mut AB) { + // Verify Merkle openings for each proof + // Check polynomial evaluations + // Aggregate FRI checks + } +} +``` + +## 3. Trade-offs + +| Approach | Proof Size | Proving Time | Verification Time | +|----------|------------|--------------|-------------------| +| **Multi-verify** (current) | N × original | 0 (no new proof) | O(log N) | +| **Recursive aggregation** | 1 × original | O(N) | O(1) | +| **SNARK wrapping** | ~200-500 bytes | O(N) + SNARK time | O(1), very fast | + +## 4. Practical Recommendation + +For your ZKLoRA project, if you need a single compressed proof: + +1. **Short term**: Use `verify_multi` for batch verification (what you have now) +2. **Medium term**: Add a SNARK wrapper (e.g., using Groth16 via arkworks or snarkjs) to compress all proofs into one tiny proof +3. **Long term**: Wait for Plonky3 to add native recursion, or use Plonky2 for the aggregation layer + +Example SNARK wrapper approach: +```python +# Generate all Plonky3 proofs +plonky3_proofs = [generate_proof(data) for data in datasets] + +# Aggregate into single SNARK proof (~300 bytes) +aggregated_proof = snark_wrapper.aggregate_proofs(plonky3_proofs) + +# Now you have ONE proof that proves all computations +verify(aggregated_proof) # Fast, constant time +``` + +This is exactly what projects like zkEVM and zkBridge do – they generate many STARK proofs for efficiency, then aggregate them into a single SNARK for on-chain verification. + + +# Proof compression with arkworks and Groth16 + +Here's how to use **arkworks and Groth16** to compress multiple Plonky3 proofs into a single SNARK proof: + +## 1. Setup Dependencies + +First, add arkworks to your `Cargo.toml`: + +```toml +[dependencies] +ark-groth16 = "0.4" +ark-std = "0.4" +ark-ff = "0.4" +ark-ec = "0.4" +ark-relations = "0.4" +ark-r1cs-std = "0.4" +ark-crypto-primitives = "0.4" +ark-bn254 = "0.4" # or ark-bls12-381 +ark-serialize = "0.4" +``` + +## 2. Create a Verification Circuit + +The key is to build a circuit that verifies your Plonky3 proofs: + +```rust +use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; +use ark_r1cs_std::prelude::*; +use ark_ff::PrimeField; + +pub struct Plonky3VerifierCircuit { + // Serialized Plonky3 proofs + proofs: Vec>, + // Public inputs for each proof + public_inputs: Vec>, + // Expected outputs/commitments + expected_outputs: Vec, +} + +impl ConstraintSynthesizer for Plonky3VerifierCircuit { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + // For each Plonky3 proof + for (i, proof) in self.proofs.iter().enumerate() { + // Allocate proof bytes as witnesses + let proof_vars = proof.iter() + .map(|&byte| UInt8::new_witness(cs.clone(), || Ok(byte))) + .collect::, _>>()?; + + // Implement Plonky3 verification logic + // This is complex - you need to: + // 1. Verify Merkle commitments + // 2. Check FRI queries + // 3. Validate polynomial evaluations + verify_plonky3_in_circuit( + cs.clone(), + &proof_vars, + &self.public_inputs[i], + &self.expected_outputs[i], + )?; + } + Ok(()) + } +} +``` + +## 3. Simplified Approach: Hash-based Aggregation + +Since implementing full Plonky3 verification in a circuit is complex, here's a more practical approach using **commitment aggregation**: + +```rust +use ark_groth16::{Groth16, ProvingKey, VerifyingKey}; +use ark_bn254::{Bn254, Fr}; // Or your chosen curve +use ark_crypto_primitives::crh::sha256::Sha256; +use ark_r1cs_std::bits::uint8::UInt8; + +/// Circuit that verifies hash commitments of multiple proofs +pub struct ProofAggregatorCircuit { + // Hash of each Plonky3 proof + proof_hashes: Vec<[u8; 32]>, + // Merkle root of all proof hashes + merkle_root: [u8; 32], +} + +impl ConstraintSynthesizer for ProofAggregatorCircuit { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + // Allocate proof hashes + let mut hash_vars = vec![]; + for hash in &self.proof_hashes { + let hash_var = hash.iter() + .map(|&b| UInt8::new_witness(cs.clone(), || Ok(b))) + .collect::, _>>()?; + hash_vars.push(hash_var); + } + + // Compute Merkle tree in-circuit + let computed_root = compute_merkle_root_circuit(cs.clone(), &hash_vars)?; + + // Verify root matches expected + let expected_root = self.merkle_root.iter() + .map(|&b| UInt8::new_input(cs.clone(), || Ok(b))) + .collect::, _>>()?; + + computed_root.enforce_equal(&expected_root)?; + + Ok(()) + } +} +``` + +## 4. Generate the Aggregated Proof + +```rust +use ark_groth16::{prepare_verifying_key, verify_proof}; +use ark_std::rand::rngs::OsRng; +use sha2::{Sha256, Digest}; + +pub fn aggregate_plonky3_proofs( + plonky3_proofs: Vec>, +) -> Result, Box> { + // Step 1: Hash each proof + let proof_hashes: Vec<[u8; 32]> = plonky3_proofs + .iter() + .map(|proof| { + let mut hasher = Sha256::new(); + hasher.update(proof); + hasher.finalize().into() + }) + .collect(); + + // Step 2: Build Merkle tree + let merkle_root = build_merkle_root(&proof_hashes); + + // Step 3: Create circuit + let circuit = ProofAggregatorCircuit { + proof_hashes: proof_hashes.clone(), + merkle_root, + }; + + // Step 4: Setup (do this once and save keys) + let mut rng = OsRng; + let (pk, vk) = Groth16::::circuit_specific_setup( + circuit.clone(), + &mut rng + )?; + + // Step 5: Generate proof + let proof = Groth16::prove(&pk, circuit, &mut rng)?; + + // Serialize proof + let mut proof_bytes = Vec::new(); + proof.serialize_compressed(&mut proof_bytes)?; + + Ok(proof_bytes) +} +``` + +## 5. Verify the Aggregated Proof + +```rust +pub fn verify_aggregated_proof( + proof_bytes: &[u8], + merkle_root: [u8; 32], + vk: &VerifyingKey, +) -> bool { + // Deserialize proof + let proof = match Groth16Proof::deserialize_compressed(proof_bytes) { + Ok(p) => p, + Err(_) => return false, + }; + + // Public inputs = merkle root + let public_inputs = merkle_root + .iter() + .map(|&b| Fr::from(b as u64)) + .collect::>(); + + // Verify + let pvk = prepare_verifying_key(vk); + verify_proof(&pvk, &proof, &public_inputs).is_ok() +} +``` + +## 6. Python Integration + +Add Python bindings using PyO3: + +```rust +use pyo3::prelude::*; + +#[pyfunction] +pub fn compress_plonky3_proofs(proofs: Vec>) -> PyResult> { + aggregate_plonky3_proofs(proofs) + .map_err(|e| PyErr::new::( + format!("Failed to aggregate: {}", e) + )) +} + +#[pyfunction] +pub fn verify_compressed_proof( + proof: Vec, + merkle_root: Vec +) -> PyResult { + // Load VK from file or constant + let vk = load_verifying_key()?; + let mut root = [0u8; 32]; + root.copy_from_slice(&merkle_root); + + Ok(verify_aggregated_proof(&proof, root, &vk)) +} +``` + +## 7. Usage in Python + +```python +import plonky3_py as pl +import groth16_aggregator as g16 + +# Generate individual Plonky3 proofs +proofs = [] +for m, n, v, a in datasets: + proof = pl.vector_matrix_multiplication_prove(m, n, v, a) + proofs.append(proof) + +# Compress all proofs into one ~256-byte Groth16 proof +compressed_proof = g16.compress_plonky3_proofs(proofs) +print(f"Compressed {len(proofs)} proofs into {len(compressed_proof)} bytes") + +# Verify the single compressed proof +is_valid = g16.verify_compressed_proof(compressed_proof, merkle_root) +assert is_valid +``` + +## Trade-offs & Tips + +1. **Trusted Setup**: Groth16 requires a trusted setup ceremony. For production, use Powers of Tau. + +2. **Alternative: Use Plonk/Marlin**: These are universal (no circuit-specific setup): + ```toml + ark-plonk = "0.4" # or ark-marlin + ``` + +3. **Optimization**: Store proof commitments on-chain instead of full Merkle tree: + ```solidity + contract ProofRegistry { + mapping(bytes32 => bool) validProofHashes; + + function verifyAggregatedProof( + bytes calldata groth16Proof, + bytes32 merkleRoot + ) external view returns (bool) { + // Verify Groth16 proof on-chain + return verifier.verifyProof(groth16Proof, merkleRoot); + } + } + ``` + +4. **Simpler Alternative**: If you just need to prove "all these proofs are valid", you can: + - Hash all Plonky3 proofs together + - Create a Groth16 proof of the hash computation + - This is much simpler than verifying Plonky3 in-circuit + +This approach gives you a single ~256-byte proof that can be verified in ~2ms, perfect for on-chain verification or bandwidth-constrained environments. \ No newline at end of file diff --git a/src/zklora/zk_proof_generator.py b/src/zklora/zk_proof_generator.py index e49d08a..329ea08 100644 --- a/src/zklora/zk_proof_generator.py +++ b/src/zklora/zk_proof_generator.py @@ -3,12 +3,17 @@ import json import time import asyncio -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Literal import numpy as np import onnx import onnxruntime import ezkl +# Added for extracting tensors in plonky3 backend +from onnx import numpy_helper +import torch +from zklora.fp_coding import fixed_point_encode, fixed_point_decode +import plonky3_py as pl class ProofPaths(NamedTuple): @@ -83,12 +88,14 @@ def batch_verify_proofs( print(f"Total proofs verified: {len(proof_files)}") return total_verify_time, len(proof_files) +ZKBackend = Literal["ezkl", "plonky3"] async def generate_proofs( onnx_dir: str = "lora_onnx_params", json_dir: str = "intermediate_activations", output_dir: str = "proof_artifacts", verbose: bool = False, + zk_backend: ZKBackend = "ezkl", ) -> Optional[tuple[float, float, float, int, int]]: """Asynchronously scans onnx_dir for .onnx files and json_dir for .json files. For each matching pair, runs: @@ -156,85 +163,133 @@ async def generate_proofs( proof_file, ) = names - py_args = ezkl.PyRunArgs() - py_args.input_visibility = "public" - py_args.output_visibility = "public" - py_args.param_visibility = "private" - py_args.logrows = 20 + if zk_backend == "ezkl": + py_args = ezkl.PyRunArgs() + py_args.input_visibility = "public" + py_args.output_visibility = "public" + py_args.param_visibility = "private" + py_args.logrows = 20 - if verbose: - print("Generating settings & compiling circuit...") - start_time = time.time() - - # 1) gen_settings + compile_circuit - ezkl.gen_settings(onnx_path, settings_file, py_run_args=py_args) - ezkl.compile_circuit(onnx_path, circuit_name, settings_file) - - # 2) SRS + setup - if not os.path.isfile(srs_file): - ezkl.gen_srs(srs_file, py_args.logrows) - ezkl.setup(circuit_name, vk_file, pk_file, srs_file) - end_time = time.time() - if verbose: - print(f"Setup for {base_name} took {end_time - start_time:.2f} sec") - total_settings_time += end_time - start_time + if verbose: + print("Generating settings & compiling circuit...") + start_time = time.time() + + # 1) gen_settings + compile_circuit + ezkl.gen_settings(onnx_path, settings_file, py_run_args=py_args) + ezkl.compile_circuit(onnx_path, circuit_name, settings_file) + + # 2) SRS + setup + if not os.path.isfile(srs_file): + ezkl.gen_srs(srs_file, py_args.logrows) + ezkl.setup(circuit_name, vk_file, pk_file, srs_file) + end_time = time.time() + if verbose: + print(f"Setup for {base_name} took {end_time - start_time:.2f} sec") + total_settings_time += end_time - start_time - # Local check - with open(json_path, "r") as f: - data = json.load(f) - input_array = np.array(data["input_data"], dtype=np.float32) - if verbose: - print("Input shape from JSON:", input_array.shape) - session = onnxruntime.InferenceSession(onnx_path) - out = session.run(None, {"input_x": input_array}) - if verbose: - print("Local ONNX output shape:", out[0].shape) + # Local check + with open(json_path, "r") as f: + data = json.load(f) + input_array = np.array(data["input_data"], dtype=np.float32) + if verbose: + print("Input shape from JSON:", input_array.shape) + session = onnxruntime.InferenceSession(onnx_path) + out = session.run(None, {"input_x": input_array}) + if verbose: + print("Local ONNX output shape:", out[0].shape) - # 3) gen_witness (async via background thread) - if verbose: - print("Generating witness (async)...") - start_time = time.time() - try: - # Offload blocking ezkl.gen_witness call to a worker thread so that - # the async event loop can continue making progress. - await asyncio.to_thread( - ezkl.gen_witness, - data=json_path, - model=circuit_name, - output=witness_file, + # 3) gen_witness (async) + if verbose: + print("Generating witness (async)...") + start_time = time.time() + try: + await ezkl.gen_witness( + data=json_path, model=circuit_name, output=witness_file + ) + except RuntimeError as e: + print(f"Failed to generate witness: {e}") + continue + + if not ezkl.mock(witness_file, circuit_name): + print("Mock run failed, skipping.") + continue + + end_time = time.time() + if verbose: + print(f"Witness gen took {end_time - start_time:.2f} sec") + total_witness_time += end_time - start_time + # 4) prove + if verbose: + print("Generating proof...") + start_time = time.time() + prove_ok = ezkl.prove( + witness_file, circuit_name, pk_file, proof_file, "single", srs_file ) - except RuntimeError as e: - print(f"Failed to generate witness: {e}") - continue - - if not ezkl.mock(witness_file, circuit_name): - print("Mock run failed, skipping.") - continue - - end_time = time.time() - if verbose: - print(f"Witness gen took {end_time - start_time:.2f} sec") - total_witness_time += end_time - start_time - # 4) prove - if verbose: - print("Generating proof...") - start_time = time.time() - prove_ok = ezkl.prove( - witness_file, circuit_name, pk_file, proof_file, "single", srs_file - ) - end_time = time.time() - if verbose: - print(f"Proof gen took {end_time - start_time:.2f} sec") - total_prove_time += end_time - start_time + end_time = time.time() + if verbose: + print(f"Proof gen took {end_time - start_time:.2f} sec") + total_prove_time += end_time - start_time - if not prove_ok: - print(f"Proof generation failed for {base_name}") - continue + if not prove_ok: + print(f"Proof generation failed for {base_name}") + continue - if verbose: - print(f"Done with {base_name}.\n") - os.remove(pk_file) - count_onnx_files += 1 + if verbose: + print(f"Done with {base_name}.\n") + os.remove(pk_file) + count_onnx_files += 1 + elif zk_backend == "plonky3": # 1. load the model + tensors = { + init.name: numpy_helper.to_array(init) # 2. convert to NumPy + for init in onnx_model.graph.initializer + } + + A = tensors["A"] # LoRA A (in_dim × r) + B = tensors["B"] # LoRA B (r × out_dim) + + A_t = torch.from_numpy(A) # or load directly as torch.Tensor + B_t = torch.from_numpy(B) + + W = A_t @ B_t + m = W.shape[0] + n = W.shape[1] + + print("m =", m, "n =", n) + + print("A.shape =", A.shape, " B.shape =", B.shape) + print("W.shape =", W.shape) + + W = W.tolist() + W_encoded = [fixed_point_encode(row, fractional_bits=24) for row in W] + + print("W_encoded.shape =", len(W_encoded), len(W_encoded[0])) + + # Read input data + with open(json_path, "r") as f: + input_data = json.load(f) + + # Flatten to 1D with correct shape + x = np.array(input_data["input_data"], dtype=np.float32)[0] + x_2d = x.reshape(-1, m) # shape: (batch*seq_len, W.shape[0]) + print("batch x tokens × hidden:", x_2d.shape) + + count_proofs = 0 + for i in range(len(x_2d)): + v = x_2d[i].tolist() + v_encoded = fixed_point_encode(v, fractional_bits=24) + print("v_encoded.shape =", len(v_encoded)) + start_time = time.time() + pl.vector_matrix_multiplication_prove(m, n, v_encoded, W_encoded) + end_time = time.time() + if verbose: + print(f"Proof gen took {end_time - start_time:.2f} sec") + total_prove_time += end_time - start_time + count_proofs += 1 + if count_proofs >= 1: + break + + else: + raise ValueError(f"Invalid ZK backend: {zk_backend}") return ( total_settings_time,