Skip to content

Commit 373367e

Browse files
authored
Revert "Optional Writeback in NativeSumcheck (#35)" (#38)
This reverts commit af3ccda.
1 parent 59844f0 commit 373367e

6 files changed

Lines changed: 175 additions & 132 deletions

File tree

extensions/native/circuit/cuda/include/native/sumcheck.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ template <typename T> struct NativeSumcheckCols {
8282

8383
T eval_acc[EXT_DEG];
8484

85-
T is_writeback;
85+
T is_hint_src_id;
8686

8787
T specific[COL_SPECIFIC_WIDTH];
8888
};

extensions/native/circuit/cuda/src/sumcheck.cu

Lines changed: 25 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -32,55 +32,34 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
3232
);
3333
} else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) {
3434
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
35-
if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) {
36-
mem_fill_base(
37-
mem_helper,
38-
start_timestamp,
39-
specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base))
40-
);
41-
mem_fill_base(
42-
mem_helper,
43-
start_timestamp + 1,
44-
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
45-
);
46-
} else {
47-
mem_fill_base(
48-
mem_helper,
49-
start_timestamp,
50-
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
51-
);
52-
}
35+
mem_fill_base(
36+
mem_helper,
37+
start_timestamp,
38+
specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base))
39+
);
40+
mem_fill_base(
41+
mem_helper,
42+
start_timestamp + 1,
43+
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
44+
);
5345
}
5446
} else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) {
5547
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
56-
if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) {
57-
mem_fill_base(
58-
mem_helper,
59-
start_timestamp,
60-
specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base))
61-
);
62-
mem_fill_base(
63-
mem_helper,
64-
start_timestamp + 1,
65-
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
66-
);
67-
mem_fill_base(
68-
mem_helper,
69-
start_timestamp + 2,
70-
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
71-
);
72-
} else {
73-
mem_fill_base(
74-
mem_helper,
75-
start_timestamp,
76-
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
77-
);
78-
mem_fill_base(
79-
mem_helper,
80-
start_timestamp + 1,
81-
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
82-
);
83-
}
48+
mem_fill_base(
49+
mem_helper,
50+
start_timestamp,
51+
specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base))
52+
);
53+
mem_fill_base(
54+
mem_helper,
55+
start_timestamp + 1,
56+
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
57+
);
58+
mem_fill_base(
59+
mem_helper,
60+
start_timestamp + 2,
61+
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
62+
);
8463
}
8564
}
8665
}

extensions/native/circuit/src/sumcheck/air.rs

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::borrow::Borrow;
33
use openvm_circuit::{
44
arch::{ExecutionBridge, ExecutionState},
55
system::memory::{
6-
offline_checker::MemoryBridge,
6+
offline_checker::{MemoryBridge, MemoryReadAuxCols},
77
MemoryAddress,
88
},
99
};
@@ -26,6 +26,9 @@ use crate::{
2626
},
2727
};
2828

29+
pub const NUM_RWS_FOR_PRODUCT: usize = 2;
30+
pub const NUM_RWS_FOR_LOGUP: usize = 3;
31+
2932
#[derive(Clone, Debug)]
3033
pub struct NativeSumcheckAir {
3134
pub execution_bridge: ExecutionBridge,
@@ -102,7 +105,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
102105
within_round_limit,
103106
should_acc,
104107
eval_acc,
105-
is_writeback,
108+
is_hint_src_id,
106109
specific,
107110
} = local;
108111

@@ -232,6 +235,22 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
232235
next.start_timestamp,
233236
start_timestamp + AB::F::from_canonical_usize(8),
234237
);
238+
builder
239+
.when(prod_row)
240+
.when(next.prod_row + next.logup_row)
241+
.assert_eq(
242+
next.start_timestamp,
243+
start_timestamp
244+
+ within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT),
245+
);
246+
builder
247+
.when(logup_row)
248+
.when(next.prod_row + next.logup_row)
249+
.assert_eq(
250+
next.start_timestamp,
251+
start_timestamp
252+
+ within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP),
253+
);
235254

236255
// Termination condition
237256
assert_array_eq(
@@ -330,7 +349,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
330349
native_as,
331350
register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN),
332351
),
333-
[max_round, is_writeback],
352+
[max_round, is_hint_src_id],
334353
first_timestamp + AB::F::from_canonical_usize(7),
335354
&header_row_specific.read_records[7],
336355
)
@@ -373,6 +392,21 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
373392
);
374393
builder.assert_eq(prod_row * should_acc, prod_acc);
375394

395+
// Read p1, p2 from witness arrays
396+
self.memory_bridge
397+
.read(
398+
MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr),
399+
prod_row_specific.p,
400+
start_timestamp,
401+
&MemoryReadAuxCols {
402+
base: prod_row_specific.ps_record.base,
403+
},
404+
)
405+
.eval(
406+
builder,
407+
(prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id),
408+
);
409+
376410
// Obtain p1, p2 from hint space and write back to witness arrays
377411
self.memory_bridge
378412
.write(
@@ -383,7 +417,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
383417
)
384418
.eval(
385419
builder,
386-
(prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback,
420+
(prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id,
387421
);
388422

389423
let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap();
@@ -398,7 +432,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
398432
register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG),
399433
),
400434
prod_row_specific.p_evals,
401-
start_timestamp + is_writeback * AB::F::ONE,
435+
start_timestamp + AB::F::ONE,
402436
&prod_row_specific.write_record,
403437
)
404438
.eval(builder, prod_row * within_round_limit);
@@ -465,6 +499,21 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
465499
);
466500
builder.assert_eq(logup_row * should_acc, logup_acc);
467501

502+
// Read p1, p2, q1, q2 from witness arrays
503+
self.memory_bridge
504+
.read(
505+
MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr),
506+
logup_row_specific.pq,
507+
start_timestamp,
508+
&MemoryReadAuxCols {
509+
base: logup_row_specific.pqs_record.base,
510+
},
511+
)
512+
.eval(
513+
builder,
514+
(logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id),
515+
);
516+
468517
// Obtain p1, p2, q1, q2 from hint space
469518
self.memory_bridge
470519
.write(
@@ -475,7 +524,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
475524
)
476525
.eval(
477526
builder,
478-
(logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback,
527+
(logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id,
479528
);
480529
let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap();
481530
let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)]
@@ -497,7 +546,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
497546
+ (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG),
498547
),
499548
logup_row_specific.p_evals,
500-
start_timestamp + is_writeback * AB::F::ONE,
549+
start_timestamp + AB::F::ONE,
501550
&logup_row_specific.write_records[0],
502551
)
503552
.eval(builder, logup_row * within_round_limit);
@@ -512,7 +561,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
512561
* AB::F::from_canonical_usize(EXT_DEG),
513562
),
514563
logup_row_specific.q_evals,
515-
start_timestamp + is_writeback * AB::F::ONE + AB::F::ONE,
564+
start_timestamp + AB::F::TWO,
516565
&logup_row_specific.write_records[1],
517566
)
518567
.eval(builder, logup_row * within_round_limit);

0 commit comments

Comments
 (0)