Skip to content

Commit a695fb9

Browse files
authored
Fix: cuda type alignment for NativeSumcheck (#33)
* clippy * fix gpu bug
1 parent 8b4c69f commit a695fb9

6 files changed

Lines changed: 24 additions & 23 deletions

File tree

extensions/native/circuit/cuda/include/native/sumcheck.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ template <typename T> struct HeaderSpecificCols {
1717
template <typename T> struct ProdSpecificCols {
1818
T data_ptr;
1919
T p[EXT_DEG * 2];
20-
MemoryReadAuxCols<T> read_records[1];
2120
T p_evals[EXT_DEG];
2221
MemoryWriteAuxCols<T, EXT_DEG> write_record;
22+
MemoryWriteAuxCols<T, EXT_DEG * 2> ps_record;
2323
T eval_rlc[EXT_DEG];
2424
};
2525

2626
template <typename T> struct LogupSpecificCols {
2727
T data_ptr;
2828
T pq[EXT_DEG * 4];
29-
MemoryReadAuxCols<T> read_records[1];
3029
T p_evals[EXT_DEG];
3130
T q_evals[EXT_DEG];
31+
MemoryWriteAuxCols<T, EXT_DEG * 4> pqs_record;
3232
MemoryWriteAuxCols<T, EXT_DEG> write_records[2];
3333
T eval_rlc[EXT_DEG];
3434
};

extensions/native/circuit/cuda/src/sumcheck.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66

77
using namespace native;
88

9+
constexpr uint32_t header_read_records_len() {
10+
return sizeof(((HeaderSpecificCols<uint8_t> *)nullptr)->read_records)
11+
/ sizeof(MemoryReadAuxCols<uint8_t>);
12+
}
13+
914
__device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_helper) {
1015
RowSlice specific = row.slice_from(COL_INDEX(NativeSumcheckCols, specific));
1116
uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32();
1217

1318
if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) {
14-
for (uint32_t i = 0; i < 8; ++i) {
19+
constexpr uint32_t header_records = header_read_records_len();
20+
for (uint32_t i = 0; i < header_records; ++i) {
1521
mem_fill_base(
1622
mem_helper,
1723
start_timestamp + i,

extensions/native/circuit/src/sumcheck/chip.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ use openvm_circuit::{
1010
native_adapter::util::{memory_read_native, tracing_write_native_inplace},
1111
},
1212
};
13-
use openvm_instructions::{
14-
instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS,
15-
};
13+
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
1614
use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL;
1715
use openvm_stark_backend::p3_field::PrimeField32;
1816

@@ -227,7 +225,8 @@ where
227225
let mut eval_acc = elem_to_ext(F::from_canonical_u32(0));
228226
let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1));
229227

230-
// all rows share same register values, ctx, challenges, max_round, hint_space_ptrs (optional)
228+
// all rows share same register values, ctx, challenges, max_round, hint_space_ptrs
229+
// (optional)
231230
for row in rows.iter_mut() {
232231
// c1, c2 are same during the entire execution
233232
row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]);

extensions/native/circuit/src/sumcheck/execution.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl NativeSumcheckExecutor {
3333
#[inline(always)]
3434
fn pre_compute_impl<F: PrimeField32>(
3535
&self,
36-
pc: u32,
36+
_pc: u32,
3737
inst: &Instruction<F>,
3838
data: &mut NativeSumcheckPreCompute,
3939
) -> Result<(), StaticProgramError> {

extensions/native/compiler/src/ir/sumcheck.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ impl<C: Config> Builder<C> {
2626
///
2727
/// 2. for computing expected eval of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) *
2828
/// p[i][1].
29+
#[allow(clippy::too_many_arguments)]
2930
pub fn sumcheck_layer_eval(
3031
&mut self,
3132
input_ctx: &Array<C, Usize<C::N>>, // Context variables

extensions/native/recursion/tests/sumcheck.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::iter::{once, repeat_n};
1+
use std::iter::once;
22

33
use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl};
44
#[cfg(feature = "cuda")]
@@ -35,12 +35,10 @@ fn test_sumcheck_layer_eval_with_hint_ids() {
3535
let num_logup_specs = 8;
3636

3737
let prod_evals: Vec<E> = (0..(num_prod_specs * num_layers * 2))
38-
.into_iter()
3938
.map(|_| new_rand_ext(&mut rng))
4039
.collect();
4140

4241
let logup_evals: Vec<E> = (0..(num_logup_specs * num_layers * 4))
43-
.into_iter()
4442
.map(|_| new_rand_ext(&mut rng))
4543
.collect();
4644

@@ -73,13 +71,10 @@ fn test_sumcheck_layer_eval_with_hint_ids() {
7371
standard_fri_params_with_100_bits_conjectured_security(1)
7472
};
7573

76-
let mut input_stream: Vec<Vec<F>> = vec![];
77-
input_stream.push(
78-
prod_evals
79-
.into_iter()
80-
.flat_map(|e| <E as FieldExtensionAlgebra<F>>::as_base_slice(&e).to_vec())
81-
.collect(),
82-
);
74+
let mut input_stream: Vec<Vec<F>> = vec![prod_evals
75+
.into_iter()
76+
.flat_map(|e| <E as FieldExtensionAlgebra<F>>::as_base_slice(&e).to_vec())
77+
.collect()];
8378
input_stream.push(
8479
logup_evals
8580
.into_iter()
@@ -137,7 +132,7 @@ fn build_test_program<C: Config>(
137132
) {
138133
let mode = 1; // current_layer
139134

140-
let mut ctx_u32s = vec![
135+
let ctx_u32s = vec![
141136
round,
142137
num_prod_specs,
143138
num_logup_specs,
@@ -175,16 +170,16 @@ fn build_test_program<C: Config>(
175170

176171
let num_prod_evals = num_prod_specs * num_layers * 2;
177172
let prod_spec_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(num_prod_evals);
178-
for idx in 0..num_prod_evals {
179-
let e: Ext<C::F, C::EF> = builder.constant(prod_evals[idx]);
173+
for (idx, prod_eval) in prod_evals.into_iter().enumerate() {
174+
let e: Ext<C::F, C::EF> = builder.constant(prod_eval);
180175

181176
builder.set(&prod_spec_evals, idx, e);
182177
}
183178

184179
let num_logup_evals = num_logup_specs * num_layers * 4;
185180
let logup_spec_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(num_logup_evals);
186-
for idx in 0..num_logup_evals {
187-
let e: Ext<C::F, C::EF> = builder.constant(logup_evals[idx]);
181+
for (idx, logup_eval) in logup_evals.into_iter().enumerate() {
182+
let e: Ext<C::F, C::EF> = builder.constant(logup_eval);
188183

189184
builder.set(&logup_spec_evals, idx, e);
190185
}

0 commit comments

Comments
 (0)