1- use p3_air:: { Air , AirBuilder , BaseAir } ;
2- use p3_challenger:: HashChallenger ;
3- use p3_challenger:: SerializingChallenger32 ;
4- use p3_circle:: CirclePcs ;
5- use p3_commit:: ExtensionMmcs ;
6- use p3_field:: { extension:: BinomialExtensionField , PrimeCharacteristicRing , PrimeField } ;
7- use p3_fri:: FriParameters ;
8- use p3_keccak:: Keccak256Hash ;
1+ use p3_field:: { PrimeCharacteristicRing , PrimeField } ;
92use p3_matrix:: { dense:: RowMajorMatrix , Matrix } ;
10- use p3_merkle_tree:: MerkleTreeMmcs ;
113use p3_mersenne_31:: Mersenne31 ;
12- use p3_symmetric:: CompressionFunctionFromHasher ;
13- use p3_symmetric:: SerializingHasher ;
14- use p3_uni_stark:: { prove, verify, StarkConfig } ;
4+ use p3_uni_stark:: { prove, verify} ;
155use pyo3:: prelude:: * ;
166use pyo3:: types:: PyModule ;
177use pyo3:: wrap_pyfunction;
188use pyo3:: Bound ;
9+
10+ pub mod vector_matrix_air;
11+
12+ use vector_matrix_air:: { VectorMatrixMultiplicationAIR , MyConfig } ;
13+
1914/// Multiplies a vector by a matrix (vector * matrix).
2015///
2116/// # Arguments
@@ -42,368 +37,6 @@ pub fn vector_matrix_multiply<F: PrimeField>(a: &Vec<F>, b: &RowMajorMatrix<F>)
4237 result
4338}
4439
45- /// AIR (Algebraic Intermediate Representation) for vector-matrix multiplication.
46- /// This struct represents the configuration and constraints for proving vector-matrix multiplication
47- /// using an algebraic execution trace. It tracks the dimensions of the input matrices
48- /// and provides methods for generating and verifying the computation trace.
49- pub struct VectorMatrixMultiplicationAIR < F : PrimeField > {
50- /// Length of vector (number of rows in matrix)
51- pub m : usize ,
52- /// Number of columns in matrix
53- pub n : usize ,
54-
55- pub byte_hash : ByteHash ,
56- pub field_hash : FieldHash ,
57- pub compress : MyCompress ,
58- pub val_mmcs : ValMmcs ,
59- pub challenge_mmcs : ChallengeMmcs ,
60- pub config : MyConfig ,
61-
62- /// Field element type
63- _phantom : std:: marker:: PhantomData < F > ,
64- }
65-
66- // Field
67- type Val = Mersenne31 ;
68-
69- // This creates a cubic extension field over Val using a binomial basis. It's used for generating challenges in the proof system.
70- // The reason why we want to extend our field for Challenges, is because the original Field size is too small to be brute-forced to solve the challenge.
71- type Challenge = BinomialExtensionField < Val , 3 > ;
72- // Your choice of Hash Function
73- type ByteHash = Keccak256Hash ;
74- // A serializer for Hash function, so that it can take Fields as inputs
75- type FieldHash = SerializingHasher < ByteHash > ;
76- // Defines a compression function type using ByteHash, with 2 input blocks and 32-byte output.
77- type MyCompress = CompressionFunctionFromHasher < ByteHash , 2 , 32 > ;
78- // Defines a Merkle tree commitment scheme for field elements with 32 levels.
79- type ValMmcs = MerkleTreeMmcs < Val , u8 , FieldHash , MyCompress , 32 > ;
80- // Defines an extension of the Merkle tree commitment scheme for the challenge field.
81- type ChallengeMmcs = ExtensionMmcs < Val , Challenge , ValMmcs > ;
82- // Defines the challenger type for generating random challenges.
83- type Challenger = SerializingChallenger32 < Val , HashChallenger < u8 , ByteHash , 32 > > ;
84- // Defines the polynomial commitment scheme type.
85- type Pcs = CirclePcs < Val , ValMmcs , ChallengeMmcs > ;
86- // Defines the overall STARK configuration type.
87- type MyConfig = StarkConfig < Pcs , Challenge , Challenger > ;
88-
89- impl < F : PrimeField > VectorMatrixMultiplicationAIR < F > {
90- pub fn new ( m : usize , n : usize ) -> Self {
91- // Declaring an empty hash and its serializer.
92- let byte_hash = ByteHash { } ;
93- // Declaring Field hash function, it is used to hash field elements in the proof system
94- let field_hash = FieldHash :: new ( byte_hash) ;
95- // Creates a new instance of the compression function.
96- let compress = MyCompress :: new ( byte_hash) ;
97- // Instantiates the Merkle tree commitment scheme.
98- let val_mmcs = ValMmcs :: new ( field_hash, compress. clone ( ) ) ;
99- // Creates an instance of the challenge Merkle tree commitment scheme.
100- let challenge_mmcs = ChallengeMmcs :: new ( val_mmcs. clone ( ) ) ;
101- // Configures the FRI (Fast Reed-Solomon IOP) protocol parameters.
102- let fri_config = FriParameters {
103- log_blowup : 1 ,
104- num_queries : 100 ,
105- proof_of_work_bits : 16 ,
106- mmcs : challenge_mmcs. clone ( ) ,
107- log_final_poly_len : 1 ,
108- } ;
109- // Instantiates the polynomial commitment scheme with the above parameters.
110- let pcs = Pcs {
111- mmcs : val_mmcs. clone ( ) ,
112- fri_params : fri_config,
113- _phantom : std:: marker:: PhantomData ,
114- } ;
115- let challenger = Challenger :: from_hasher ( vec ! [ ] , byte_hash) ;
116- // Creates the STARK configuration instance.
117- let config = MyConfig :: new ( pcs, challenger) ;
118-
119- Self {
120- m,
121- n,
122- byte_hash,
123- field_hash,
124- compress,
125- val_mmcs,
126- challenge_mmcs,
127- config,
128- _phantom : std:: marker:: PhantomData ,
129- }
130- }
131-
132- /// Returns the number of columns in the execution trace
133- ///
134- /// The execution trace for matrix multiplication requires columns to store:
135- /// - All elements from vector (m columns)
136- /// - All elements from matrix (n * m columns)
137- /// - m + (m * n) columns for the selectors
138- /// - One column for the running sum of the current element being computed
139- /// - One column for the row in the trace that is enabled
140- ///
141- /// The total width is calculated as: 2 * (m + m * n) + 1
142- pub fn trace_width ( & self ) -> usize {
143- 2 * ( self . m + self . m * self . n ) + 1 + 1
144- }
145-
146- /// Pushes the matrix elements to the trace data with column major order
147- fn push_matrix ( & self , trace_data : & mut Vec < F > , a : & RowMajorMatrix < F > ) {
148- for i in 0 ..self . n {
149- for j in 0 ..self . m {
150- trace_data. push ( a. get ( j, i) . unwrap ( ) ) ;
151- }
152- }
153- }
154-
155- /// Generates a computation trace for vector-matrix multiplication
156- ///
157- /// # Arguments
158- /// * `v` - Input vector (length m) with field elements
159- /// * `a` - Matrix (m x n) with field elements
160- ///
161- /// # Returns
162- /// A matrix representing the execution trace, where each row is a step in the computation
163- /// and columns correspond to:
164- /// - Columns 0..m: All elements from vector v
165- /// - Columns m..(m+m*n): All elements from matrix a (in column-major order)
166- /// - Columns (m+m*n)..(m+m*n+m): Vector selector (one-hot encoding for vector elements)
167- /// - Columns (m+m*n+m)..(m+m*n+m+m*n): Matrix selector (one-hot encoding for matrix elements)
168- /// - Column trace_width()-2: Running sum for the current dot product computation
169- /// - Column trace_width()-1: Enabled flag (1 if row is active, 0 if padding)
170- pub fn generate_trace ( & self , v : & Vec < F > , a : & RowMajorMatrix < F > ) -> RowMajorMatrix < F > {
171- assert_eq ! (
172- a. height( ) ,
173- self . m,
174- "Matrix height should match AIR configuration"
175- ) ;
176- assert_eq ! (
177- a. width( ) ,
178- self . n,
179- "Matrix width should match AIR configuration"
180- ) ;
181- assert_eq ! (
182- v. len( ) ,
183- self . m,
184- "Vector length should match AIR configuration"
185- ) ;
186-
187- // Compute total number of steps needed for the trace
188- // For each element V[i], we need m steps to compute the dot product
189- let total_rows = self . m * self . n ;
190-
191- // Initialize the trace matrix with F elements
192- let mut trace_data: Vec < F > = Vec :: with_capacity ( total_rows * self . trace_width ( ) ) ;
193-
194- let mut vector_selector: Vec < F > = vec ! [ F :: ONE ]
195- . into_iter ( )
196- . chain ( std:: iter:: repeat ( F :: ZERO ) . take ( self . m - 1 ) )
197- . collect ( ) ;
198- let mut matrix_selector: Vec < F > = vec ! [ F :: ONE ]
199- . into_iter ( )
200- . chain ( std:: iter:: repeat ( F :: ZERO ) . take ( self . m * self . n - 1 ) )
201- . collect ( ) ;
202-
203- let mut previous_sum = F :: ZERO ;
204-
205- // Generate the step-by-step trace
206- for _ in 0 ..total_rows {
207- trace_data. extend_from_slice ( v) ;
208- self . push_matrix ( & mut trace_data, a) ;
209- trace_data. extend_from_slice ( & vector_selector) ;
210- trace_data. extend_from_slice ( & matrix_selector) ;
211-
212- // Find the index in vector_selector where the value is F::ONE
213- let vector_index = vector_selector
214- . iter ( )
215- . position ( |& x| x == F :: ONE )
216- . expect ( "vector_selector should contain F::ONE" ) ;
217-
218- // Get the index in matrix_selector where the value is F::ONE
219- let matrix_index = matrix_selector
220- . iter ( )
221- . position ( |& x| x == F :: ONE )
222- . expect ( "matrix_selector should contain F::ONE" ) ;
223-
224- let running_sum = if vector_index > 0 {
225- trace_data[ vector_index] * trace_data[ self . m + matrix_index] + previous_sum
226- } else {
227- trace_data[ vector_index] * trace_data[ self . m + matrix_index]
228- } ;
229- trace_data. push ( running_sum) ;
230- previous_sum = running_sum. clone ( ) ;
231-
232- trace_data. push ( F :: ONE ) ;
233-
234- vector_selector. rotate_right ( 1 ) ;
235- matrix_selector. rotate_right ( 1 ) ;
236- }
237-
238- // If the trace length is not a power of two, pad it with dummy rows of zeros so that
239- // the total number of rows is the next power of two. The last column (running sum)
240- // is explicitly set to 0 for these padding rows to indicate that they are inactive.
241-
242- let width = self . trace_width ( ) ;
243- let padded_rows = total_rows. next_power_of_two ( ) ;
244- if padded_rows > total_rows {
245- for _ in 0 ..( padded_rows - total_rows) {
246- // Push `width` zeros. Since we are inside a PrimeField, F::ZERO is valid.
247- // The last column (running sum) remains 0 as well.
248- trace_data. extend ( std:: iter:: repeat ( F :: ZERO ) . take ( width) ) ;
249- }
250- }
251-
252- RowMajorMatrix :: new ( trace_data, width)
253- }
254- }
255-
256- impl < F : PrimeField > BaseAir < F > for VectorMatrixMultiplicationAIR < F > {
257- fn width ( & self ) -> usize {
258- self . trace_width ( )
259- }
260- }
261-
262- impl < AB : AirBuilder > Air < AB > for VectorMatrixMultiplicationAIR < AB :: F >
263- where
264- AB :: F : PrimeField ,
265- {
266- fn eval ( & self , builder : & mut AB ) {
267- let main = builder. main ( ) ;
268- let current = main. row_slice ( 0 ) . unwrap ( ) ;
269- let next = main. row_slice ( 1 ) . unwrap ( ) ;
270-
271- let v_sel_init = self . n * self . m + self . m ;
272- let m_sel_init = self . n * self . m + self . m + self . m ;
273- let matrix_init = self . m ;
274- let sum = self . trace_width ( ) - 2 ;
275- let enabled = self . trace_width ( ) - 1 ;
276-
277- // Enforce starting state
278- // the row is enabled
279- builder
280- . when_first_row ( )
281- . assert_one ( current[ enabled] . clone ( ) ) ;
282-
283- // sum equal the first element of the vector times the first element of the matrix
284- builder. when_first_row ( ) . assert_eq (
285- current[ sum] . clone ( ) ,
286- current[ 0 ] . clone ( ) * current[ matrix_init] . clone ( ) ,
287- ) ;
288-
289- // The first element of the vector selector is 1
290- builder
291- . when_first_row ( )
292- . assert_one ( current[ v_sel_init] . clone ( ) ) ;
293- // The rest of the vector selectors are 0
294- for i in 1 ..self . m {
295- builder
296- . when_first_row ( )
297- . assert_zero ( current[ v_sel_init + i] . clone ( ) ) ;
298- }
299-
300- // The first element of the matrix selector is 1
301- builder
302- . when_first_row ( )
303- . assert_one ( current[ m_sel_init] . clone ( ) ) ;
304- // The rest of the matrix selectors are 0
305- for i in 1 ..self . m {
306- builder
307- . when_first_row ( )
308- . assert_zero ( current[ m_sel_init + i] . clone ( ) ) ;
309- }
310-
311- // Enforce final enabled row is followed by disabled row
312- builder
313- . when_transition ( )
314- . when ( current[ enabled] . clone ( ) )
315- . when ( current[ sum - 1 ] . clone ( ) )
316- . assert_zero ( next[ enabled] . clone ( ) ) ;
317-
318- // Enforce rows are all 0 in the last column after the last enabled row
319- builder
320- . when_transition ( )
321- . when ( AB :: Expr :: ONE - current[ enabled] . clone ( ) )
322- . assert_zero ( next[ enabled] . clone ( ) ) ;
323-
324- // Enforce booleanity of the vector selector
325- for i in 0 ..self . m {
326- builder
327- . when_transition ( )
328- . when ( current[ enabled] . clone ( ) )
329- . assert_bool ( current[ v_sel_init + i] . clone ( ) ) ;
330- }
331-
332- // Enforce booleanity of the matrix selector
333- for i in 0 ..self . m {
334- builder
335- . when_transition ( )
336- . when ( current[ enabled] . clone ( ) )
337- . assert_bool ( current[ m_sel_init + i] . clone ( ) ) ;
338- }
339-
340- // Enforce booleanity of the enabled colum
341- builder
342- . when_transition ( )
343- . assert_bool ( current[ enabled] . clone ( ) ) ;
344- builder. when_last_row ( ) . assert_bool ( next[ enabled] . clone ( ) ) ;
345-
346- // Enforce the sum of the vector selector is 1
347- let mut acum = AB :: Expr :: ZERO ;
348- for i in 0 ..self . m {
349- acum += current[ v_sel_init + i] . clone ( ) ;
350- }
351- builder
352- . when_transition ( )
353- . when ( current[ enabled] . clone ( ) )
354- . assert_eq ( acum, AB :: Expr :: ONE ) ;
355-
356- // Enforce the sum of the matrix selector is 1
357- let mut acum = AB :: Expr :: ZERO ;
358- for i in 0 ..self . m * self . n {
359- acum += current[ m_sel_init + i] . clone ( ) ;
360- }
361- builder
362- . when_transition ( )
363- . when ( current[ enabled] . clone ( ) )
364- . assert_eq ( acum, AB :: Expr :: ONE ) ;
365-
366- // Enforce the vector and matrix do not change between rows
367- for i in 0 ..self . m + self . m * self . n {
368- builder
369- . when_transition ( )
370- . when ( current[ enabled] . clone ( ) )
371- . when ( next[ enabled] . clone ( ) )
372- . assert_eq ( current[ i] . clone ( ) , next[ i] . clone ( ) ) ;
373- }
374-
375- // Enforce the correct vector-matrix multiplication result
376- // If the first element of the vector selector is 1, then
377- // the sum colum does not accumulate from the previous row
378- for i in 0 ..self . m * self . n {
379- builder
380- . when_transition ( )
381- . when ( current[ enabled] . clone ( ) )
382- . when ( current[ v_sel_init] . clone ( ) )
383- . when ( current[ m_sel_init + i] . clone ( ) )
384- . assert_eq (
385- current[ sum] . clone ( ) ,
386- current[ 0 ] . clone ( ) * current[ matrix_init + i] . clone ( ) ,
387- ) ;
388- }
389- // If the first element of the vector selector is 0, then
390- // the sum colum accumulates from the previous row
391- for i in 1 ..self . m {
392- for j in 0 ..self . m * self . n {
393- builder
394- . when_transition ( )
395- . when ( next[ enabled] . clone ( ) )
396- . when ( AB :: Expr :: ONE - next[ v_sel_init] . clone ( ) )
397- . when ( next[ v_sel_init + i] . clone ( ) )
398- . when ( next[ m_sel_init + j] . clone ( ) )
399- . assert_eq (
400- next[ sum] . clone ( ) ,
401- current[ sum] . clone ( ) + next[ i] . clone ( ) * next[ matrix_init + j] . clone ( ) ,
402- ) ;
403- }
404- }
405- }
406- }
40740
40841fn vector_matrix_transform (
40942 m : usize ,
0 commit comments