|
1 | 1 | use p3_air::{Air, BaseAir, ExtensionBuilder, WindowAccess}; |
2 | 2 | use p3_field::{PrimeCharacteristicRing, batch_multiplicative_inverse}; |
3 | 3 | use p3_matrix::{Matrix, dense::RowMajorMatrix}; |
| 4 | +use p3_maybe_rayon::prelude::*; |
4 | 5 |
|
5 | 6 | use crate::{ |
6 | 7 | builder::{TwoStagedBuilder, symbolic::SymbolicExpression}, |
@@ -113,29 +114,23 @@ impl Lookup<Val> { |
113 | 114 | fingerprint_challenge: &ExtVal, |
114 | 115 | mut accumulator: ExtVal, |
115 | 116 | ) -> (Vec<RowMajorMatrix<ExtVal>>, Vec<ExtVal>) { |
116 | | - // Collect the number of lookups per circuit while accumulating the total |
117 | | - // number of lookups. |
118 | | - let mut num_lookups_per_circuit = Vec::with_capacity(lookups.len()); |
119 | | - let mut total_num_lookups = 0; |
120 | | - for circuit_lookups in lookups { |
121 | | - let num_rows = circuit_lookups.len(); |
122 | | - // Every row is assumed to have the same number of lookups, which is |
123 | | - // the number of lookups of the first row. |
124 | | - let num_row_lookups = circuit_lookups[0].len(); |
125 | | - let num_circuit_lookups = num_rows * num_row_lookups; |
126 | | - num_lookups_per_circuit.push(num_circuit_lookups); |
127 | | - total_num_lookups += num_circuit_lookups; |
128 | | - } |
| 117 | + // Number of lookups per circuit. Every row in a circuit is assumed to |
| 118 | + // have the same number of lookups (the lookups are expected to be fully |
| 119 | + // padded), so this is taken from the first row. |
| 120 | + let num_lookups_per_circuit: Vec<usize> = lookups |
| 121 | + .iter() |
| 122 | + .map(|circuit_lookups| circuit_lookups.len() * circuit_lookups[0].len()) |
| 123 | + .collect(); |
129 | 124 |
|
130 | | - // Compute and collect all messages. There's one message per lookup. |
131 | | - let mut messages = Vec::with_capacity(total_num_lookups); |
132 | | - for circuit_lookups in lookups { |
133 | | - let circuit_messages = circuit_lookups |
134 | | - .iter() |
135 | | - .flatten() |
136 | | - .map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge)); |
137 | | - messages.extend(circuit_messages); |
138 | | - } |
| 125 | + // Compute the message for each lookup, in flat circuit-major order. |
| 126 | + // Flatten the references serially first so the parallel map operates |
| 127 | + // on an indexed slice and `collect` can write straight into the |
| 128 | + // output Vec without tree-reducing worker buffers. |
| 129 | + let flat: Vec<&Self> = lookups.iter().flatten().flatten().collect(); |
| 130 | + let messages: Vec<ExtVal> = flat |
| 131 | + .par_iter() |
| 132 | + .map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge)) |
| 133 | + .collect(); |
139 | 134 |
|
140 | 135 | // Compute the inverses of all messages in batch. |
141 | 136 | let messages_inverses = batch_multiplicative_inverse(&messages); |
|
0 commit comments