Skip to content

Commit a8de5c3

Browse files
authored
Merge pull request #40 from scroll-tech/feat/hint_bridge_gpu
This PR adds CUDA/GPU support for the HintSpaceProvider periphery chip used by the new HintBridge, enabling GPU-side trace generation for the hint space provider table and registering the chip in the GPU prover extension.
2 parents 435b8ad + f2cd332 commit a8de5c3

5 files changed

Lines changed: 213 additions & 5 deletions

File tree

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "launcher.cuh"
2+
#include "primitives/trace_access.h"
3+
4+
// Columns layout matches HintSpaceProviderCols<T> in hint_space_provider.rs
5+
// Fields: hint_id, offset, value, is_valid
6+
template <typename T> struct HintSpaceProviderCols {
7+
T hint_id;
8+
T offset;
9+
T value;
10+
T is_valid;
11+
};
12+
13+
constexpr uint32_t HINT_SPACE_PROVIDER_WIDTH = sizeof(HintSpaceProviderCols<uint8_t>);
14+
15+
__global__ void hint_space_provider_tracegen(
16+
Fp *trace,
17+
size_t height,
18+
const Fp *records,
19+
size_t rows_used
20+
) {
21+
uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
22+
if (idx >= height) {
23+
return;
24+
}
25+
26+
RowSlice row(trace + idx, height);
27+
if (idx < rows_used) {
28+
// Each record is a triple (hint_id, offset, value)
29+
const Fp *rec = records + idx * 3;
30+
COL_WRITE_VALUE(row, HintSpaceProviderCols, hint_id, rec[0]);
31+
COL_WRITE_VALUE(row, HintSpaceProviderCols, offset, rec[1]);
32+
COL_WRITE_VALUE(row, HintSpaceProviderCols, value, rec[2]);
33+
COL_WRITE_VALUE(row, HintSpaceProviderCols, is_valid, Fp::one());
34+
} else {
35+
row.fill_zero(0, HINT_SPACE_PROVIDER_WIDTH);
36+
}
37+
}
38+
39+
extern "C" int _hint_space_provider_tracegen(
40+
Fp *d_trace,
41+
size_t height,
42+
size_t width,
43+
const Fp *d_records,
44+
size_t rows_used
45+
) {
46+
assert((height & (height - 1)) == 0);
47+
assert(width == HINT_SPACE_PROVIDER_WIDTH);
48+
auto [grid, block] = kernel_launch_params(height);
49+
hint_space_provider_tracegen<<<grid, block>>>(
50+
d_trace,
51+
height,
52+
d_records,
53+
rows_used
54+
);
55+
return CHECK_KERNEL();
56+
}

extensions/native/circuit/src/cuda_abi.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,33 @@ pub mod native_jal_rangecheck_cuda {
345345
))
346346
}
347347
}
348+
349+
pub mod hint_space_provider_cuda {
350+
use super::*;
351+
352+
extern "C" {
353+
pub fn _hint_space_provider_tracegen(
354+
d_trace: *mut F,
355+
height: usize,
356+
width: usize,
357+
d_records: *const F,
358+
rows_used: usize,
359+
) -> i32;
360+
}
361+
362+
pub unsafe fn tracegen(
363+
d_trace: &DeviceBuffer<F>,
364+
height: usize,
365+
width: usize,
366+
d_records: &DeviceBuffer<F>,
367+
rows_used: usize,
368+
) -> Result<(), CudaError> {
369+
CudaError::from_result(_hint_space_provider_tracegen(
370+
d_trace.as_mut_ptr(),
371+
height,
372+
width,
373+
d_records.as_ptr(),
374+
rows_used,
375+
))
376+
}
377+
}

extensions/native/circuit/src/extension/cuda.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use openvm_circuit::{
24
arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension},
35
system::cuda::extensions::get_inventory_range_checker,
@@ -14,6 +16,7 @@ use crate::{
1416
field_arithmetic::{FieldArithmeticAir, FieldArithmeticChipGpu},
1517
field_extension::{FieldExtensionAir, FieldExtensionChipGpu},
1618
fri::{FriReducedOpeningAir, FriReducedOpeningChipGpu},
19+
hint_space_provider::{cuda::HintSpaceProviderChipGpu, HintSpaceProviderAir, HintSpaceProviderChip},
1720
jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu},
1821
loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu},
1922
poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu},
@@ -76,8 +79,16 @@ impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
7679
let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
7780
inventory.add_executor_chip(poseidon2);
7881

82+
// HintSpaceProvider must be registered BEFORE NativeSumcheck because chips are
83+
// dispatched in reverse order: sumcheck runs first and populates the provider.
84+
let hint_air: &HintSpaceProviderAir = inventory.next_air::<HintSpaceProviderAir>()?;
85+
let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus));
86+
let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone());
87+
inventory.add_periphery_chip(provider_gpu);
88+
7989
inventory.next_air::<NativeSumcheckAir>()?;
80-
let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits);
90+
let sumcheck =
91+
NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip);
8192
inventory.add_executor_chip(sumcheck);
8293

8394
Ok(())

extensions/native/circuit/src/hint_space_provider.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,56 @@ impl<F: PrimeField32> ChipUsageGetter for HintSpaceProviderChip<F> {
130130
NUM_HINT_SPACE_PROVIDER_COLS
131131
}
132132
}
133+
134+
#[cfg(feature = "cuda")]
135+
pub mod cuda {
136+
use std::sync::Arc;
137+
138+
use openvm_circuit::arch::DenseRecordArena;
139+
use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
140+
use openvm_cuda_common::copy::MemCopyH2D;
141+
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
142+
143+
use super::{HintSpaceProviderChip, NUM_HINT_SPACE_PROVIDER_COLS};
144+
use crate::cuda_abi::hint_space_provider_cuda;
145+
146+
pub struct HintSpaceProviderChipGpu {
147+
pub cpu_chip: Arc<HintSpaceProviderChip<F>>,
148+
}
149+
150+
impl HintSpaceProviderChipGpu {
151+
pub fn new(cpu_chip: Arc<HintSpaceProviderChip<F>>) -> Self {
152+
Self { cpu_chip }
153+
}
154+
}
155+
156+
impl Chip<DenseRecordArena, GpuBackend> for HintSpaceProviderChipGpu {
157+
fn generate_proving_ctx(&self, _: DenseRecordArena) -> AirProvingContext<GpuBackend> {
158+
let data = std::mem::take(&mut *self.cpu_chip.data.lock().unwrap());
159+
let rows_used = data.len();
160+
let height = rows_used.next_power_of_two().max(2);
161+
162+
// Flatten (hint_id, offset, value) triples into a contiguous [F] buffer
163+
let flat: Vec<F> = data
164+
.into_iter()
165+
.flat_map(|(h, o, v)| [h, o, v])
166+
.collect();
167+
168+
let d_records = flat.to_device().unwrap();
169+
let trace = DeviceMatrix::<F>::with_capacity(height, NUM_HINT_SPACE_PROVIDER_COLS);
170+
171+
unsafe {
172+
hint_space_provider_cuda::tracegen(
173+
trace.buffer(),
174+
height,
175+
NUM_HINT_SPACE_PROVIDER_COLS,
176+
&d_records,
177+
rows_used,
178+
)
179+
.unwrap();
180+
}
181+
182+
AirProvingContext::simple_no_pis(trace)
183+
}
184+
}
185+
}

extensions/native/circuit/src/sumcheck/cuda.rs

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{mem::size_of, slice::from_raw_parts, sync::Arc};
1+
use std::{borrow::Borrow, mem::size_of, slice::from_raw_parts, sync::Arc};
22

33
use derive_new::new;
44
use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
@@ -7,15 +7,70 @@ use openvm_cuda_backend::{
77
base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
88
};
99
use openvm_cuda_common::copy::MemCopyH2D;
10-
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
10+
use openvm_stark_backend::{p3_field::PrimeField32, prover::types::AirProvingContext, Chip};
1111

12-
use super::columns::NativeSumcheckCols;
13-
use crate::cuda_abi::sumcheck_cuda;
12+
use super::columns::{LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols};
13+
use crate::{
14+
cuda_abi::sumcheck_cuda,
15+
hint_space_provider::SharedHintSpaceProviderChip,
16+
};
1417

1518
#[derive(new)]
1619
pub struct NativeSumcheckChipGpu {
1720
pub range_checker: Arc<VariableRangeCheckerChipGPU>,
1821
pub timestamp_max_bits: usize,
22+
pub hint_space_provider: SharedHintSpaceProviderChip<F>,
23+
}
24+
25+
impl NativeSumcheckChipGpu {
26+
/// Scans execution records to populate the hint space provider with
27+
/// (hint_id, offset, value) triples for each hint element referenced
28+
/// by prod and logup rows. This bridges the gap between CPU execution
29+
/// (which produces the records) and GPU trace generation.
30+
fn populate_hint_provider(&self, records: &[u8]) {
31+
let width = NativeSumcheckCols::<F>::width();
32+
let record_size = width * size_of::<F>();
33+
if records.len() % record_size != 0 {
34+
return;
35+
}
36+
let num_rows = records.len() / record_size;
37+
38+
let row_slice = unsafe {
39+
let ptr = records.as_ptr() as *const F;
40+
from_raw_parts(ptr, num_rows * width)
41+
};
42+
43+
for i in 0..num_rows {
44+
let row_data = &row_slice[i * width..(i + 1) * width];
45+
let cols: &NativeSumcheckCols<F> = row_data.borrow();
46+
47+
if cols.within_round_limit != F::ONE {
48+
continue;
49+
}
50+
51+
if cols.prod_row == F::ONE {
52+
let prod_specific: &ProdSpecificCols<F> =
53+
cols.specific[..ProdSpecificCols::<F>::width()].borrow();
54+
for (j, &val) in prod_specific.p.iter().enumerate() {
55+
self.hint_space_provider.request(
56+
cols.prod_hint_id,
57+
prod_specific.data_ptr + F::from_canonical_usize(j),
58+
val,
59+
);
60+
}
61+
} else if cols.logup_row == F::ONE {
62+
let logup_specific: &LogupSpecificCols<F> =
63+
cols.specific[..LogupSpecificCols::<F>::width()].borrow();
64+
for (j, &val) in logup_specific.pq.iter().enumerate() {
65+
self.hint_space_provider.request(
66+
cols.logup_hint_id,
67+
logup_specific.data_ptr + F::from_canonical_usize(j),
68+
val,
69+
);
70+
}
71+
}
72+
}
73+
}
1974
}
2075

2176
impl Chip<DenseRecordArena, GpuBackend> for NativeSumcheckChipGpu {
@@ -25,6 +80,9 @@ impl Chip<DenseRecordArena, GpuBackend> for NativeSumcheckChipGpu {
2580
return get_empty_air_proving_ctx::<GpuBackend>();
2681
}
2782

83+
// Populate hint space provider from execution records before GPU upload.
84+
self.populate_hint_provider(records);
85+
2886
let width = NativeSumcheckCols::<F>::width();
2987
let record_size = width * size_of::<F>();
3088
assert_eq!(records.len() % record_size, 0);

0 commit comments

Comments
 (0)