Skip to content

Commit 65387a2

Browse files
Parallelize lookup message collection (#52)
Replace the serial loop that computes per-lookup messages with a rayon parallel iteration over a flat Vec of lookup references that is materialized serially up front. Flattening first gives the parallel map an indexed slice to work on, so `collect` can write each message straight into its final position in the output Vec instead of tree-reducing per-worker buffers.
1 parent aef0a00 commit 65387a2

1 file changed

Lines changed: 17 additions & 22 deletions

File tree

src/lookup.rs

Lines changed: 17 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
use p3_air::{Air, BaseAir, ExtensionBuilder, WindowAccess};
22
use p3_field::{PrimeCharacteristicRing, batch_multiplicative_inverse};
33
use p3_matrix::{Matrix, dense::RowMajorMatrix};
4+
use p3_maybe_rayon::prelude::*;
45

56
use crate::{
67
builder::{TwoStagedBuilder, symbolic::SymbolicExpression},
@@ -113,29 +114,23 @@ impl Lookup<Val> {
113114
fingerprint_challenge: &ExtVal,
114115
mut accumulator: ExtVal,
115116
) -> (Vec<RowMajorMatrix<ExtVal>>, Vec<ExtVal>) {
116-
// Collect the number of lookups per circuit while accumulating the total
117-
// number of lookups.
118-
let mut num_lookups_per_circuit = Vec::with_capacity(lookups.len());
119-
let mut total_num_lookups = 0;
120-
for circuit_lookups in lookups {
121-
let num_rows = circuit_lookups.len();
122-
// Every row is assumed to have the same number of lookups, which is
123-
// the number of lookups of the first row.
124-
let num_row_lookups = circuit_lookups[0].len();
125-
let num_circuit_lookups = num_rows * num_row_lookups;
126-
num_lookups_per_circuit.push(num_circuit_lookups);
127-
total_num_lookups += num_circuit_lookups;
128-
}
117+
// Number of lookups per circuit. Every row in a circuit is assumed to
118+
// have the same number of lookups (the lookups are expected to be fully
119+
// padded), so this is taken from the first row.
120+
let num_lookups_per_circuit: Vec<usize> = lookups
121+
.iter()
122+
.map(|circuit_lookups| circuit_lookups.len() * circuit_lookups[0].len())
123+
.collect();
129124

130-
// Compute and collect all messages. There's one message per lookup.
131-
let mut messages = Vec::with_capacity(total_num_lookups);
132-
for circuit_lookups in lookups {
133-
let circuit_messages = circuit_lookups
134-
.iter()
135-
.flatten()
136-
.map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge));
137-
messages.extend(circuit_messages);
138-
}
125+
// Compute the message for each lookup, in flat circuit-major order.
126+
// Flatten the references serially first so the parallel map operates
127+
// on an indexed slice and `collect` can write straight into the
128+
// output Vec without tree-reducing worker buffers.
129+
let flat: Vec<&Self> = lookups.iter().flatten().flatten().collect();
130+
let messages: Vec<ExtVal> = flat
131+
.par_iter()
132+
.map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge))
133+
.collect();
139134

140135
// Compute the inverses of all messages in batch.
141136
let messages_inverses = batch_multiplicative_inverse(&messages);

0 commit comments

Comments
 (0)