Skip to content

Commit af3ccda

Browse files
darth-cyhero78119
andauthored
Optional Writeback in NativeSumcheck (#35)
* optional writeback * Correct timestamp * fix constraints * correct constraints * remove unnecessary declaration * remove change * misc: fix syntax error (#37) --------- Co-authored-by: Ming <hero78119@gmail.com>
1 parent 8ea50e3 commit af3ccda

6 files changed

Lines changed: 132 additions & 175 deletions

File tree

extensions/native/circuit/cuda/include/native/sumcheck.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ template <typename T> struct NativeSumcheckCols {
8282

8383
T eval_acc[EXT_DEG];
8484

85-
T is_hint_src_id;
85+
T is_writeback;
8686

8787
T specific[COL_SPECIFIC_WIDTH];
8888
};

extensions/native/circuit/cuda/src/sumcheck.cu

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,34 +32,55 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
3232
);
3333
} else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) {
3434
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
35-
mem_fill_base(
36-
mem_helper,
37-
start_timestamp,
38-
specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base))
39-
);
40-
mem_fill_base(
41-
mem_helper,
42-
start_timestamp + 1,
43-
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
44-
);
35+
if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) {
36+
mem_fill_base(
37+
mem_helper,
38+
start_timestamp,
39+
specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base))
40+
);
41+
mem_fill_base(
42+
mem_helper,
43+
start_timestamp + 1,
44+
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
45+
);
46+
} else {
47+
mem_fill_base(
48+
mem_helper,
49+
start_timestamp,
50+
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
51+
);
52+
}
4553
}
4654
} else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) {
4755
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
48-
mem_fill_base(
49-
mem_helper,
50-
start_timestamp,
51-
specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base))
52-
);
53-
mem_fill_base(
54-
mem_helper,
55-
start_timestamp + 1,
56-
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
57-
);
58-
mem_fill_base(
59-
mem_helper,
60-
start_timestamp + 2,
61-
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
62-
);
56+
if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) {
57+
mem_fill_base(
58+
mem_helper,
59+
start_timestamp,
60+
specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base))
61+
);
62+
mem_fill_base(
63+
mem_helper,
64+
start_timestamp + 1,
65+
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
66+
);
67+
mem_fill_base(
68+
mem_helper,
69+
start_timestamp + 2,
70+
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
71+
);
72+
} else {
73+
mem_fill_base(
74+
mem_helper,
75+
start_timestamp,
76+
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
77+
);
78+
mem_fill_base(
79+
mem_helper,
80+
start_timestamp + 1,
81+
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
82+
);
83+
}
6384
}
6485
}
6586
}

extensions/native/circuit/src/sumcheck/air.rs

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::borrow::Borrow;
33
use openvm_circuit::{
44
arch::{ExecutionBridge, ExecutionState},
55
system::memory::{
6-
offline_checker::{MemoryBridge, MemoryReadAuxCols},
6+
offline_checker::MemoryBridge,
77
MemoryAddress,
88
},
99
};
@@ -26,9 +26,6 @@ use crate::{
2626
},
2727
};
2828

29-
pub const NUM_RWS_FOR_PRODUCT: usize = 2;
30-
pub const NUM_RWS_FOR_LOGUP: usize = 3;
31-
3229
#[derive(Clone, Debug)]
3330
pub struct NativeSumcheckAir {
3431
pub execution_bridge: ExecutionBridge,
@@ -105,7 +102,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
105102
within_round_limit,
106103
should_acc,
107104
eval_acc,
108-
is_hint_src_id,
105+
is_writeback,
109106
specific,
110107
} = local;
111108

@@ -235,22 +232,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
235232
next.start_timestamp,
236233
start_timestamp + AB::F::from_canonical_usize(8),
237234
);
238-
builder
239-
.when(prod_row)
240-
.when(next.prod_row + next.logup_row)
241-
.assert_eq(
242-
next.start_timestamp,
243-
start_timestamp
244-
+ within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT),
245-
);
246-
builder
247-
.when(logup_row)
248-
.when(next.prod_row + next.logup_row)
249-
.assert_eq(
250-
next.start_timestamp,
251-
start_timestamp
252-
+ within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP),
253-
);
254235

255236
// Termination condition
256237
assert_array_eq(
@@ -349,7 +330,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
349330
native_as,
350331
register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN),
351332
),
352-
[max_round, is_hint_src_id],
333+
[max_round, is_writeback],
353334
first_timestamp + AB::F::from_canonical_usize(7),
354335
&header_row_specific.read_records[7],
355336
)
@@ -392,21 +373,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
392373
);
393374
builder.assert_eq(prod_row * should_acc, prod_acc);
394375

395-
// Read p1, p2 from witness arrays
396-
self.memory_bridge
397-
.read(
398-
MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr),
399-
prod_row_specific.p,
400-
start_timestamp,
401-
&MemoryReadAuxCols {
402-
base: prod_row_specific.ps_record.base,
403-
},
404-
)
405-
.eval(
406-
builder,
407-
(prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id),
408-
);
409-
410376
// Obtain p1, p2 from hint space and write back to witness arrays
411377
self.memory_bridge
412378
.write(
@@ -417,7 +383,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
417383
)
418384
.eval(
419385
builder,
420-
(prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id,
386+
(prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback,
421387
);
422388

423389
let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap();
@@ -432,7 +398,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
432398
register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG),
433399
),
434400
prod_row_specific.p_evals,
435-
start_timestamp + AB::F::ONE,
401+
start_timestamp + is_writeback * AB::F::ONE,
436402
&prod_row_specific.write_record,
437403
)
438404
.eval(builder, prod_row * within_round_limit);
@@ -499,21 +465,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
499465
);
500466
builder.assert_eq(logup_row * should_acc, logup_acc);
501467

502-
// Read p1, p2, q1, q2 from witness arrays
503-
self.memory_bridge
504-
.read(
505-
MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr),
506-
logup_row_specific.pq,
507-
start_timestamp,
508-
&MemoryReadAuxCols {
509-
base: logup_row_specific.pqs_record.base,
510-
},
511-
)
512-
.eval(
513-
builder,
514-
(logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id),
515-
);
516-
517468
// Obtain p1, p2, q1, q2 from hint space
518469
self.memory_bridge
519470
.write(
@@ -524,7 +475,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
524475
)
525476
.eval(
526477
builder,
527-
(logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id,
478+
(logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback,
528479
);
529480
let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap();
530481
let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)]
@@ -546,7 +497,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
546497
+ (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG),
547498
),
548499
logup_row_specific.p_evals,
549-
start_timestamp + AB::F::ONE,
500+
start_timestamp + is_writeback * AB::F::ONE,
550501
&logup_row_specific.write_records[0],
551502
)
552503
.eval(builder, logup_row * within_round_limit);
@@ -561,7 +512,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
561512
* AB::F::from_canonical_usize(EXT_DEG),
562513
),
563514
logup_row_specific.q_evals,
564-
start_timestamp + AB::F::TWO,
515+
start_timestamp + is_writeback * AB::F::ONE + AB::F::ONE,
565516
&logup_row_specific.write_records[1],
566517
)
567518
.eval(builder, logup_row * within_round_limit);

0 commit comments

Comments
 (0)