Skip to content

Commit ac5d1b4

Browse files
Reduce Context Operations in Native Sumcheck (#20)
* Remove ctx reads * fmt * fix bugs in air and cpu tracegen * fix gpu tracegen bug * clippy --------- Co-authored-by: kunxian xia <xiakunxian130@gmail.com>
1 parent 554df1c commit ac5d1b4

7 files changed

Lines changed: 67 additions & 126 deletions

File tree

extensions/native/circuit/cuda/include/native/sumcheck.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ using namespace native;
88
template <typename T> struct HeaderSpecificCols {
99
T pc;
1010
T registers[5];
11-
MemoryReadAuxCols<T> read_records[7];
11+
MemoryReadAuxCols<T> read_records[8];
1212
MemoryWriteAuxCols<T, EXT_DEG> write_records;
1313
};
1414

1515
template <typename T> struct ProdSpecificCols {
1616
T data_ptr;
1717
T p[EXT_DEG * 2];
18-
MemoryReadAuxCols<T> read_records[2];
18+
MemoryReadAuxCols<T> read_records[1];
1919
T p_evals[EXT_DEG];
2020
MemoryWriteAuxCols<T, EXT_DEG> write_record;
2121
T eval_rlc[EXT_DEG];
@@ -24,7 +24,7 @@ template <typename T> struct ProdSpecificCols {
2424
template <typename T> struct LogupSpecificCols {
2525
T data_ptr;
2626
T pq[EXT_DEG * 4];
27-
MemoryReadAuxCols<T> read_records[2];
27+
MemoryReadAuxCols<T> read_records[1];
2828
T p_evals[EXT_DEG];
2929
T q_evals[EXT_DEG];
3030
MemoryWriteAuxCols<T, EXT_DEG> write_records[2];

extensions/native/circuit/cuda/src/sumcheck.cu

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
1111
uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32();
1212

1313
if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) {
14-
for (uint32_t i = 0; i < 7; ++i) {
14+
for (uint32_t i = 0; i < 8; ++i) {
1515
mem_fill_base(
1616
mem_helper,
1717
start_timestamp + i,
@@ -25,43 +25,33 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
2525
specific.slice_from(COL_INDEX(HeaderSpecificCols, write_records.base))
2626
);
2727
} else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) {
28-
mem_fill_base(
29-
mem_helper,
30-
start_timestamp,
31-
specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base))
32-
);
3328
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
3429
mem_fill_base(
3530
mem_helper,
36-
start_timestamp + 1,
37-
specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[1].base))
31+
start_timestamp,
32+
specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base))
3833
);
3934
mem_fill_base(
4035
mem_helper,
41-
start_timestamp + 2,
36+
start_timestamp + 1,
4237
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
4338
);
4439
}
4540
} else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) {
46-
mem_fill_base(
47-
mem_helper,
48-
start_timestamp,
49-
specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base))
50-
);
5141
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
5242
mem_fill_base(
5343
mem_helper,
54-
start_timestamp + 1,
55-
specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[1].base))
44+
start_timestamp,
45+
specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base))
5646
);
5747
mem_fill_base(
5848
mem_helper,
59-
start_timestamp + 2,
49+
start_timestamp + 1,
6050
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
6151
);
6252
mem_fill_base(
6353
mem_helper,
64-
start_timestamp + 3,
54+
start_timestamp + 2,
6555
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
6656
);
6757
}

extensions/native/circuit/src/sumcheck/air.rs

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
167167
alpha,
168168
next.alpha,
169169
);
170+
builder
171+
.when(next.prod_row + next.logup_row)
172+
.assert_eq(max_round, next.max_round);
170173
builder
171174
.when(next.prod_row + next.logup_row)
172175
.assert_eq(prod_nested_len, next.prod_nested_len);
@@ -223,21 +226,21 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
223226
.when(next.prod_row + next.logup_row)
224227
.assert_eq(
225228
next.start_timestamp,
226-
start_timestamp + AB::F::from_canonical_usize(7),
229+
start_timestamp + AB::F::from_canonical_usize(8),
227230
);
228231
builder
229232
.when(prod_row)
230233
.when(next.prod_row + next.logup_row)
231234
.assert_eq(
232235
next.start_timestamp,
233-
start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO,
236+
start_timestamp + within_round_limit * AB::F::TWO,
234237
);
235238
builder
236239
.when(logup_row)
237240
.when(next.prod_row + next.logup_row)
238241
.assert_eq(
239242
next.start_timestamp,
240-
start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3),
243+
start_timestamp + within_round_limit * AB::F::from_canonical_usize(3),
241244
);
242245

243246
// Termination condition
@@ -330,6 +333,19 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
330333
)
331334
.eval(builder, header_row);
332335

336+
// Read max_round
337+
self.memory_bridge
338+
.read(
339+
MemoryAddress::new(
340+
native_as,
341+
register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN),
342+
),
343+
[max_round],
344+
first_timestamp + AB::F::from_canonical_usize(7),
345+
&header_row_specific.read_records[7],
346+
)
347+
.eval(builder, header_row);
348+
333349
// Write final result
334350
self.memory_bridge
335351
.write(
@@ -348,20 +364,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
348364
let next_prod_row_specific: &ProdSpecificCols<AB::Var> =
349365
next.specific[..ProdSpecificCols::<AB::Var>::width()].borrow();
350366

351-
self.memory_bridge
352-
.read(
353-
MemoryAddress::new(
354-
native_as,
355-
register_ptrs[0]
356-
+ AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN)
357-
+ (curr_prod_n - AB::F::ONE),
358-
), // curr_prod_n starts at 1.
359-
[max_round],
360-
start_timestamp,
361-
&prod_row_specific.read_records[0],
362-
)
363-
.eval(builder, prod_row);
364-
365367
// prod_row * within_round_limit =
366368
// prod_in_round_evaluation + prod_next_round_evaluation
367369
builder
@@ -385,8 +387,8 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
385387
.read(
386388
MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr),
387389
prod_row_specific.p,
388-
start_timestamp + AB::F::ONE,
389-
&prod_row_specific.read_records[1],
390+
start_timestamp,
391+
&prod_row_specific.read_records[0],
390392
)
391393
.eval(builder, prod_row * within_round_limit);
392394

@@ -402,7 +404,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
402404
register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG),
403405
),
404406
prod_row_specific.p_evals,
405-
start_timestamp + AB::F::TWO,
407+
start_timestamp + AB::F::ONE,
406408
&prod_row_specific.write_record,
407409
)
408410
.eval(builder, prod_row * within_round_limit);
@@ -449,21 +451,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
449451
let next_logup_row_specfic: &LogupSpecificCols<AB::Var> =
450452
next.specific[..LogupSpecificCols::<AB::Var>::width()].borrow();
451453

452-
self.memory_bridge
453-
.read(
454-
MemoryAddress::new(
455-
native_as,
456-
register_ptrs[0]
457-
+ AB::F::from_canonical_usize(EXT_DEG * 2)
458-
+ num_prod_spec
459-
+ (curr_logup_n - AB::F::ONE),
460-
), // curr_logup_n starts at 1.
461-
[max_round],
462-
start_timestamp,
463-
&logup_row_specific.read_records[0],
464-
)
465-
.eval(builder, logup_row);
466-
467454
// logup_row * within_round_limit =
468455
// logup_in_round_evaluation + logup_next_round_evaluation
469456
builder
@@ -488,8 +475,8 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
488475
.read(
489476
MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr),
490477
logup_row_specific.pq,
491-
start_timestamp + AB::F::ONE,
492-
&logup_row_specific.read_records[1],
478+
start_timestamp,
479+
&logup_row_specific.read_records[0],
493480
)
494481
.eval(builder, logup_row * within_round_limit);
495482

@@ -513,7 +500,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
513500
+ (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG),
514501
),
515502
logup_row_specific.p_evals,
516-
start_timestamp + AB::F::TWO,
503+
start_timestamp + AB::F::ONE,
517504
&logup_row_specific.write_records[0],
518505
)
519506
.eval(builder, logup_row * within_round_limit);
@@ -528,7 +515,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
528515
* AB::F::from_canonical_usize(EXT_DEG),
529516
),
530517
logup_row_specific.q_evals,
531-
start_timestamp + AB::F::from_canonical_usize(3),
518+
start_timestamp + AB::F::TWO,
532519
&logup_row_specific.write_records[1],
533520
)
534521
.eval(builder, logup_row * within_round_limit);

0 commit comments

Comments
 (0)