Skip to content

Commit 511ac23

Browse files
committed
Refactor vector-matrix multiplication implementation by extracting the VectorMatrixMultiplicationAIR struct into a new module. Simplify imports and remove unused code, enhancing clarity and maintainability of the proof strategy.
1 parent 3cbcbf4 commit 511ac23

2 files changed

Lines changed: 385 additions & 374 deletions

File tree

src/zklora/libs/plonky3/src/lib.rs

Lines changed: 7 additions & 374 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,16 @@
1-
use p3_air::{Air, AirBuilder, BaseAir};
2-
use p3_challenger::HashChallenger;
3-
use p3_challenger::SerializingChallenger32;
4-
use p3_circle::CirclePcs;
5-
use p3_commit::ExtensionMmcs;
6-
use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing, PrimeField};
7-
use p3_fri::FriParameters;
8-
use p3_keccak::Keccak256Hash;
1+
use p3_field::{PrimeCharacteristicRing, PrimeField};
92
use p3_matrix::{dense::RowMajorMatrix, Matrix};
10-
use p3_merkle_tree::MerkleTreeMmcs;
113
use p3_mersenne_31::Mersenne31;
12-
use p3_symmetric::CompressionFunctionFromHasher;
13-
use p3_symmetric::SerializingHasher;
14-
use p3_uni_stark::{prove, verify, StarkConfig};
4+
use p3_uni_stark::{prove, verify};
155
use pyo3::prelude::*;
166
use pyo3::types::PyModule;
177
use pyo3::wrap_pyfunction;
188
use pyo3::Bound;
9+
10+
pub mod vector_matrix_air;
11+
12+
use vector_matrix_air::{VectorMatrixMultiplicationAIR, MyConfig};
13+
1914
/// Multiplies a vector by a matrix (vector * matrix).
2015
///
2116
/// # Arguments
@@ -42,368 +37,6 @@ pub fn vector_matrix_multiply<F: PrimeField>(a: &Vec<F>, b: &RowMajorMatrix<F>)
4237
result
4338
}
4439

45-
/// AIR (Algebraic Intermediate Representation) for vector-matrix multiplication.
46-
/// This struct represents the configuration and constraints for proving vector-matrix multiplication
47-
/// using an algebraic execution trace. It tracks the dimensions of the input matrices
48-
/// and provides methods for generating and verifying the computation trace.
49-
pub struct VectorMatrixMultiplicationAIR<F: PrimeField> {
50-
/// Length of vector (number of rows in matrix)
51-
pub m: usize,
52-
/// Number of columns in matrix
53-
pub n: usize,
54-
55-
pub byte_hash: ByteHash,
56-
pub field_hash: FieldHash,
57-
pub compress: MyCompress,
58-
pub val_mmcs: ValMmcs,
59-
pub challenge_mmcs: ChallengeMmcs,
60-
pub config: MyConfig,
61-
62-
/// Field element type
63-
_phantom: std::marker::PhantomData<F>,
64-
}
65-
66-
// Field
67-
type Val = Mersenne31;
68-
69-
// This creates a cubic extension field over Val using a binomial basis. It's used for generating challenges in the proof system.
70-
// The reason why we want to extend our field for Challenges, is because the original Field size is too small to be brute-forced to solve the challenge.
71-
type Challenge = BinomialExtensionField<Val, 3>;
72-
// Your choice of Hash Function
73-
type ByteHash = Keccak256Hash;
74-
// A serializer for Hash function, so that it can take Fields as inputs
75-
type FieldHash = SerializingHasher<ByteHash>;
76-
// Defines a compression function type using ByteHash, with 2 input blocks and 32-byte output.
77-
type MyCompress = CompressionFunctionFromHasher<ByteHash, 2, 32>;
78-
// Defines a Merkle tree commitment scheme for field elements with 32 levels.
79-
type ValMmcs = MerkleTreeMmcs<Val, u8, FieldHash, MyCompress, 32>;
80-
// Defines an extension of the Merkle tree commitment scheme for the challenge field.
81-
type ChallengeMmcs = ExtensionMmcs<Val, Challenge, ValMmcs>;
82-
// Defines the challenger type for generating random challenges.
83-
type Challenger = SerializingChallenger32<Val, HashChallenger<u8, ByteHash, 32>>;
84-
// Defines the polynomial commitment scheme type.
85-
type Pcs = CirclePcs<Val, ValMmcs, ChallengeMmcs>;
86-
// Defines the overall STARK configuration type.
87-
type MyConfig = StarkConfig<Pcs, Challenge, Challenger>;
88-
89-
impl<F: PrimeField> VectorMatrixMultiplicationAIR<F> {
90-
pub fn new(m: usize, n: usize) -> Self {
91-
// Declaring an empty hash and its serializer.
92-
let byte_hash = ByteHash {};
93-
// Declaring Field hash function, it is used to hash field elements in the proof system
94-
let field_hash = FieldHash::new(byte_hash);
95-
// Creates a new instance of the compression function.
96-
let compress = MyCompress::new(byte_hash);
97-
// Instantiates the Merkle tree commitment scheme.
98-
let val_mmcs = ValMmcs::new(field_hash, compress.clone());
99-
// Creates an instance of the challenge Merkle tree commitment scheme.
100-
let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
101-
// Configures the FRI (Fast Reed-Solomon IOP) protocol parameters.
102-
let fri_config = FriParameters {
103-
log_blowup: 1,
104-
num_queries: 100,
105-
proof_of_work_bits: 16,
106-
mmcs: challenge_mmcs.clone(),
107-
log_final_poly_len: 1,
108-
};
109-
// Instantiates the polynomial commitment scheme with the above parameters.
110-
let pcs = Pcs {
111-
mmcs: val_mmcs.clone(),
112-
fri_params: fri_config,
113-
_phantom: std::marker::PhantomData,
114-
};
115-
let challenger = Challenger::from_hasher(vec![], byte_hash);
116-
// Creates the STARK configuration instance.
117-
let config = MyConfig::new(pcs, challenger);
118-
119-
Self {
120-
m,
121-
n,
122-
byte_hash,
123-
field_hash,
124-
compress,
125-
val_mmcs,
126-
challenge_mmcs,
127-
config,
128-
_phantom: std::marker::PhantomData,
129-
}
130-
}
131-
132-
/// Returns the number of columns in the execution trace
133-
///
134-
/// The execution trace for matrix multiplication requires columns to store:
135-
/// - All elements from vector (m columns)
136-
/// - All elements from matrix (n * m columns)
137-
/// - m + (m * n) columns for the selectors
138-
/// - One column for the running sum of the current element being computed
139-
/// - One column for the row in the trace that is enabled
140-
///
141-
/// The total width is calculated as: 2 * (m + m * n) + 1
142-
pub fn trace_width(&self) -> usize {
143-
2 * (self.m + self.m * self.n) + 1 + 1
144-
}
145-
146-
/// Pushes the matrix elements to the trace data with column major order
147-
fn push_matrix(&self, trace_data: &mut Vec<F>, a: &RowMajorMatrix<F>) {
148-
for i in 0..self.n {
149-
for j in 0..self.m {
150-
trace_data.push(a.get(j, i).unwrap());
151-
}
152-
}
153-
}
154-
155-
/// Generates a computation trace for vector-matrix multiplication
156-
///
157-
/// # Arguments
158-
/// * `v` - Input vector (length m) with field elements
159-
/// * `a` - Matrix (m x n) with field elements
160-
///
161-
/// # Returns
162-
/// A matrix representing the execution trace, where each row is a step in the computation
163-
/// and columns correspond to:
164-
/// - Columns 0..m: All elements from vector v
165-
/// - Columns m..(m+m*n): All elements from matrix a (in column-major order)
166-
/// - Columns (m+m*n)..(m+m*n+m): Vector selector (one-hot encoding for vector elements)
167-
/// - Columns (m+m*n+m)..(m+m*n+m+m*n): Matrix selector (one-hot encoding for matrix elements)
168-
/// - Column trace_width()-2: Running sum for the current dot product computation
169-
/// - Column trace_width()-1: Enabled flag (1 if row is active, 0 if padding)
170-
pub fn generate_trace(&self, v: &Vec<F>, a: &RowMajorMatrix<F>) -> RowMajorMatrix<F> {
171-
assert_eq!(
172-
a.height(),
173-
self.m,
174-
"Matrix height should match AIR configuration"
175-
);
176-
assert_eq!(
177-
a.width(),
178-
self.n,
179-
"Matrix width should match AIR configuration"
180-
);
181-
assert_eq!(
182-
v.len(),
183-
self.m,
184-
"Vector length should match AIR configuration"
185-
);
186-
187-
// Compute total number of steps needed for the trace
188-
// For each element V[i], we need m steps to compute the dot product
189-
let total_rows = self.m * self.n;
190-
191-
// Initialize the trace matrix with F elements
192-
let mut trace_data: Vec<F> = Vec::with_capacity(total_rows * self.trace_width());
193-
194-
let mut vector_selector: Vec<F> = vec![F::ONE]
195-
.into_iter()
196-
.chain(std::iter::repeat(F::ZERO).take(self.m - 1))
197-
.collect();
198-
let mut matrix_selector: Vec<F> = vec![F::ONE]
199-
.into_iter()
200-
.chain(std::iter::repeat(F::ZERO).take(self.m * self.n - 1))
201-
.collect();
202-
203-
let mut previous_sum = F::ZERO;
204-
205-
// Generate the step-by-step trace
206-
for _ in 0..total_rows {
207-
trace_data.extend_from_slice(v);
208-
self.push_matrix(&mut trace_data, a);
209-
trace_data.extend_from_slice(&vector_selector);
210-
trace_data.extend_from_slice(&matrix_selector);
211-
212-
// Find the index in vector_selector where the value is F::ONE
213-
let vector_index = vector_selector
214-
.iter()
215-
.position(|&x| x == F::ONE)
216-
.expect("vector_selector should contain F::ONE");
217-
218-
// Get the index in matrix_selector where the value is F::ONE
219-
let matrix_index = matrix_selector
220-
.iter()
221-
.position(|&x| x == F::ONE)
222-
.expect("matrix_selector should contain F::ONE");
223-
224-
let running_sum = if vector_index > 0 {
225-
trace_data[vector_index] * trace_data[self.m + matrix_index] + previous_sum
226-
} else {
227-
trace_data[vector_index] * trace_data[self.m + matrix_index]
228-
};
229-
trace_data.push(running_sum);
230-
previous_sum = running_sum.clone();
231-
232-
trace_data.push(F::ONE);
233-
234-
vector_selector.rotate_right(1);
235-
matrix_selector.rotate_right(1);
236-
}
237-
238-
// If the trace length is not a power of two, pad it with dummy rows of zeros so that
239-
// the total number of rows is the next power of two. The last column (running sum)
240-
// is explicitly set to 0 for these padding rows to indicate that they are inactive.
241-
242-
let width = self.trace_width();
243-
let padded_rows = total_rows.next_power_of_two();
244-
if padded_rows > total_rows {
245-
for _ in 0..(padded_rows - total_rows) {
246-
// Push `width` zeros. Since we are inside a PrimeField, F::ZERO is valid.
247-
// The last column (running sum) remains 0 as well.
248-
trace_data.extend(std::iter::repeat(F::ZERO).take(width));
249-
}
250-
}
251-
252-
RowMajorMatrix::new(trace_data, width)
253-
}
254-
}
255-
256-
impl<F: PrimeField> BaseAir<F> for VectorMatrixMultiplicationAIR<F> {
257-
fn width(&self) -> usize {
258-
self.trace_width()
259-
}
260-
}
261-
262-
impl<AB: AirBuilder> Air<AB> for VectorMatrixMultiplicationAIR<AB::F>
263-
where
264-
AB::F: PrimeField,
265-
{
266-
fn eval(&self, builder: &mut AB) {
267-
let main = builder.main();
268-
let current = main.row_slice(0).unwrap();
269-
let next = main.row_slice(1).unwrap();
270-
271-
let v_sel_init = self.n * self.m + self.m;
272-
let m_sel_init = self.n * self.m + self.m + self.m;
273-
let matrix_init = self.m;
274-
let sum = self.trace_width() - 2;
275-
let enabled = self.trace_width() - 1;
276-
277-
// Enforce starting state
278-
// the row is enabled
279-
builder
280-
.when_first_row()
281-
.assert_one(current[enabled].clone());
282-
283-
// sum equal the first element of the vector times the first element of the matrix
284-
builder.when_first_row().assert_eq(
285-
current[sum].clone(),
286-
current[0].clone() * current[matrix_init].clone(),
287-
);
288-
289-
// The first element of the vector selector is 1
290-
builder
291-
.when_first_row()
292-
.assert_one(current[v_sel_init].clone());
293-
// The rest of the vector selectors are 0
294-
for i in 1..self.m {
295-
builder
296-
.when_first_row()
297-
.assert_zero(current[v_sel_init + i].clone());
298-
}
299-
300-
// The first element of the matrix selector is 1
301-
builder
302-
.when_first_row()
303-
.assert_one(current[m_sel_init].clone());
304-
// The rest of the matrix selectors are 0
305-
for i in 1..self.m {
306-
builder
307-
.when_first_row()
308-
.assert_zero(current[m_sel_init + i].clone());
309-
}
310-
311-
// Enforce final enabled row is followed by disabled row
312-
builder
313-
.when_transition()
314-
.when(current[enabled].clone())
315-
.when(current[sum - 1].clone())
316-
.assert_zero(next[enabled].clone());
317-
318-
// Enforce rows are all 0 in the last column after the last enabled row
319-
builder
320-
.when_transition()
321-
.when(AB::Expr::ONE - current[enabled].clone())
322-
.assert_zero(next[enabled].clone());
323-
324-
// Enforce booleanity of the vector selector
325-
for i in 0..self.m {
326-
builder
327-
.when_transition()
328-
.when(current[enabled].clone())
329-
.assert_bool(current[v_sel_init + i].clone());
330-
}
331-
332-
// Enforce booleanity of the matrix selector
333-
for i in 0..self.m {
334-
builder
335-
.when_transition()
336-
.when(current[enabled].clone())
337-
.assert_bool(current[m_sel_init + i].clone());
338-
}
339-
340-
// Enforce booleanity of the enabled colum
341-
builder
342-
.when_transition()
343-
.assert_bool(current[enabled].clone());
344-
builder.when_last_row().assert_bool(next[enabled].clone());
345-
346-
// Enforce the sum of the vector selector is 1
347-
let mut acum = AB::Expr::ZERO;
348-
for i in 0..self.m {
349-
acum += current[v_sel_init + i].clone();
350-
}
351-
builder
352-
.when_transition()
353-
.when(current[enabled].clone())
354-
.assert_eq(acum, AB::Expr::ONE);
355-
356-
// Enforce the sum of the matrix selector is 1
357-
let mut acum = AB::Expr::ZERO;
358-
for i in 0..self.m * self.n {
359-
acum += current[m_sel_init + i].clone();
360-
}
361-
builder
362-
.when_transition()
363-
.when(current[enabled].clone())
364-
.assert_eq(acum, AB::Expr::ONE);
365-
366-
// Enforce the vector and matrix do not change between rows
367-
for i in 0..self.m + self.m * self.n {
368-
builder
369-
.when_transition()
370-
.when(current[enabled].clone())
371-
.when(next[enabled].clone())
372-
.assert_eq(current[i].clone(), next[i].clone());
373-
}
374-
375-
// Enforce the correct vector-matrix multiplication result
376-
// If the first element of the vector selector is 1, then
377-
// the sum colum does not accumulate from the previous row
378-
for i in 0..self.m * self.n {
379-
builder
380-
.when_transition()
381-
.when(current[enabled].clone())
382-
.when(current[v_sel_init].clone())
383-
.when(current[m_sel_init + i].clone())
384-
.assert_eq(
385-
current[sum].clone(),
386-
current[0].clone() * current[matrix_init + i].clone(),
387-
);
388-
}
389-
// If the first element of the vector selector is 0, then
390-
// the sum colum accumulates from the previous row
391-
for i in 1..self.m {
392-
for j in 0..self.m * self.n {
393-
builder
394-
.when_transition()
395-
.when(next[enabled].clone())
396-
.when(AB::Expr::ONE - next[v_sel_init].clone())
397-
.when(next[v_sel_init + i].clone())
398-
.when(next[m_sel_init + j].clone())
399-
.assert_eq(
400-
next[sum].clone(),
401-
current[sum].clone() + next[i].clone() * next[matrix_init + j].clone(),
402-
);
403-
}
404-
}
405-
}
406-
}
40740

40841
fn vector_matrix_transform(
40942
m: usize,

0 commit comments

Comments
 (0)