Skip to content

Commit f2cd332

Browse files
committed
resolve native sumcheck gpu trace fill problem
1 parent dcd6bfc commit f2cd332

2 files changed

Lines changed: 66 additions & 7 deletions

File tree

extensions/native/circuit/src/extension/cuda.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,13 @@ impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
8282
// HintSpaceProvider must be registered BEFORE NativeSumcheck because chips are
8383
// dispatched in reverse order: sumcheck runs first and populates the provider.
8484
let hint_air: &HintSpaceProviderAir = inventory.next_air::<HintSpaceProviderAir>()?;
85-
let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus.clone()));
86-
let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip);
85+
let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus));
86+
let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone());
8787
inventory.add_periphery_chip(provider_gpu);
8888

8989
inventory.next_air::<NativeSumcheckAir>()?;
90-
let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits);
90+
let sumcheck =
91+
NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip);
9192
inventory.add_executor_chip(sumcheck);
9293

9394
Ok(())

extensions/native/circuit/src/sumcheck/cuda.rs

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{mem::size_of, slice::from_raw_parts, sync::Arc};
1+
use std::{borrow::Borrow, mem::size_of, slice::from_raw_parts, sync::Arc};
22

33
use derive_new::new;
44
use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
@@ -7,15 +7,70 @@ use openvm_cuda_backend::{
77
base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
88
};
99
use openvm_cuda_common::copy::MemCopyH2D;
10-
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
10+
use openvm_stark_backend::{p3_field::PrimeField32, prover::types::AirProvingContext, Chip};
1111

12-
use super::columns::NativeSumcheckCols;
13-
use crate::cuda_abi::sumcheck_cuda;
12+
use super::columns::{LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols};
13+
use crate::{
14+
cuda_abi::sumcheck_cuda,
15+
hint_space_provider::SharedHintSpaceProviderChip,
16+
};
1417

1518
#[derive(new)]
1619
pub struct NativeSumcheckChipGpu {
1720
pub range_checker: Arc<VariableRangeCheckerChipGPU>,
1821
pub timestamp_max_bits: usize,
22+
pub hint_space_provider: SharedHintSpaceProviderChip<F>,
23+
}
24+
25+
impl NativeSumcheckChipGpu {
26+
/// Scans execution records to populate the hint space provider with
27+
/// (hint_id, offset, value) triples for each hint element referenced
28+
/// by prod and logup rows. This bridges the gap between CPU execution
29+
/// (which produces the records) and GPU trace generation.
30+
fn populate_hint_provider(&self, records: &[u8]) {
31+
let width = NativeSumcheckCols::<F>::width();
32+
let record_size = width * size_of::<F>();
33+
if records.len() % record_size != 0 {
34+
return;
35+
}
36+
let num_rows = records.len() / record_size;
37+
38+
let row_slice = unsafe {
39+
let ptr = records.as_ptr() as *const F;
40+
from_raw_parts(ptr, num_rows * width)
41+
};
42+
43+
for i in 0..num_rows {
44+
let row_data = &row_slice[i * width..(i + 1) * width];
45+
let cols: &NativeSumcheckCols<F> = row_data.borrow();
46+
47+
if cols.within_round_limit != F::ONE {
48+
continue;
49+
}
50+
51+
if cols.prod_row == F::ONE {
52+
let prod_specific: &ProdSpecificCols<F> =
53+
cols.specific[..ProdSpecificCols::<F>::width()].borrow();
54+
for (j, &val) in prod_specific.p.iter().enumerate() {
55+
self.hint_space_provider.request(
56+
cols.prod_hint_id,
57+
prod_specific.data_ptr + F::from_canonical_usize(j),
58+
val,
59+
);
60+
}
61+
} else if cols.logup_row == F::ONE {
62+
let logup_specific: &LogupSpecificCols<F> =
63+
cols.specific[..LogupSpecificCols::<F>::width()].borrow();
64+
for (j, &val) in logup_specific.pq.iter().enumerate() {
65+
self.hint_space_provider.request(
66+
cols.logup_hint_id,
67+
logup_specific.data_ptr + F::from_canonical_usize(j),
68+
val,
69+
);
70+
}
71+
}
72+
}
73+
}
1974
}
2075

2176
impl Chip<DenseRecordArena, GpuBackend> for NativeSumcheckChipGpu {
@@ -25,6 +80,9 @@ impl Chip<DenseRecordArena, GpuBackend> for NativeSumcheckChipGpu {
2580
return get_empty_air_proving_ctx::<GpuBackend>();
2681
}
2782

83+
// Populate hint space provider from execution records before GPU upload.
84+
self.populate_hint_provider(records);
85+
2886
let width = NativeSumcheckCols::<F>::width();
2987
let record_size = width * size_of::<F>();
3088
assert_eq!(records.len() % record_size, 0);

0 commit comments

Comments
 (0)