diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 38d211e58..76ac06c0b 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -17,8 +17,8 @@ use crate::{ ZKVMWitnesses, }, tables::{ - MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig, ShardRamCircuit, - TableCircuit, + MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig, + ShardRamEcTreeCircuit, TableCircuit, }, }; use ceno_emul::{ @@ -1682,20 +1682,21 @@ pub fn generate_witness<'a, E: ExtensionField>( ) }).unwrap(); - if let Some(shard_ram_witnesses) = - zkvm_witness.get_witness(&ShardRamCircuit::::name()) + if let Some(shard_ram_ec_tree_witnesses) = + zkvm_witness.get_witness(&ShardRamEcTreeCircuit::::name()) { info_span!("shard_ram_ec_sum").in_scope(|| { - let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses - .iter() - .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) - .map(|shard_ram_witness| { - ShardRamCircuit::::extract_ec_sum( - &system_config.mmu_config.ram_bus_circuit, - &shard_ram_witness.witness_rmms[0], - ) - }) - .sum(); + let shard_ram_ec_sum: SepticPoint = + shard_ram_ec_tree_witnesses + .iter() + .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) + .map(|shard_ram_witness| { + ShardRamEcTreeCircuit::::extract_ec_sum( + &system_config.mmu_config.ram_bus_ec_tree_circuit, + &shard_ram_witness.witness_rmms[0], + ) + }) + .sum(); let xy = shard_ram_ec_sum .x diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index 12dcbbea1..c5acd5d9d 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -7,7 +7,10 @@ use rustc_hash::FxHashSet; use crate::{ e2e::ShardContext, error::ZKVMError, - tables::{MemFinalRecord, ShardRamConfig, ShardRamRecord, Y6_LO_TOP_BYTE_LT_BOUND}, + tables::{ + MemFinalRecord, ShardRamConfig, ShardRamEcTreeConfig, ShardRamRecord, + Y6_LO_TOP_BYTE_LT_BOUND, + }, }; /// Filter and construct a cross-shard ShardRamRecord without EC computation. @@ -67,11 +70,9 @@ pub fn extract_shard_ram_column_map( let mut x = [0u32; 7]; let mut y = [0u32; 7]; - let mut slope = [0u32; 7]; for i in 0..7 { x[i] = config.x[i].id as u32; y[i] = config.y[i].id as u32; - slope[i] = config.slope[i].id as u32; } // Poseidon2 columns: p3_cols are contiguous, followed by post_linear_layer_cols @@ -111,7 +112,7 @@ pub fn extract_shard_ram_column_map( is_global_write, x, y, - slope, + slope: [0; 7], poseidon2_base_col, num_poseidon2_cols, num_p3_cols, @@ -119,12 +120,43 @@ pub fn extract_shard_ram_column_map( } } +pub fn extract_shard_ram_ec_tree_column_map( + config: &ShardRamEcTreeConfig, + num_witin: usize, +) -> ShardRamColumnMap { + let mut x = [0u32; 7]; + let mut y = [0u32; 7]; + let mut slope = [0u32; 7]; + for i in 0..7 { + x[i] = config.x[i].id as u32; + y[i] = config.y[i].id as u32; + slope[i] = config.slope[i].id as u32; + } + + ShardRamColumnMap { + addr: 0, + is_ram_register: 0, + value: [0; 2], + shard: 0, + global_clk: 0, + local_clk: 0, + nonce: 0, + is_global_write: 0, + x, + y, + slope, + poseidon2_base_col: 0, + num_poseidon2_cols: 0, + num_p3_cols: 0, + num_cols: num_witin as u32, + } +} + // --------------------------------------------------------------------------- // ShardRam EC batch computation // --------------------------------------------------------------------------- use ceno_gpu::common::witgen::types::GpuShardRamRecord; -use p3::field::FieldAlgebra; use tracing::info_span; /// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). @@ -211,7 +243,7 @@ pub(crate) fn try_gpu_assign_shard_ram( ) -> Result>, ZKVMError> { use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE; use ceno_gpu::{ - Buffer, CudaHal, + Buffer, bb31::CudaHalBB31, common::{transpose::matrix_transpose, witgen::types::GpuShardRamRecord}, }; @@ -242,8 +274,7 @@ pub(crate) fn try_gpu_assign_shard_ram( .take_while(|s| s.record.is_to_write_set) .count(); - let n = next_pow2_instance_padding(steps.len()); - let num_rows_padded = 2 * n; + let num_rows_padded = next_pow2_instance_padding(steps.len()); // 1. Convert ShardRamInput → GpuShardRamRecord let gpu_records: Vec = @@ -303,67 +334,7 @@ pub(crate) fn try_gpu_assign_shard_ram( }) })?; - // 4. GPU Phase 2: EC binary tree - let witness_buf = - tracing::info_span!("gpu_shard_ram_ec_tree", n).in_scope(|| -> Result<_, ZKVMError> { - let col_offsets = col_map.to_flat(); - let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) - })?; - - let mut init_x = vec![BB::ZERO; n * 7]; - let mut init_y = vec![BB::ZERO; n * 7]; - for (i, step) in steps.iter().enumerate() { - for j in 0..7 { - init_x[i * 7 + j] = unsafe { - *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) - }; - init_y[i * 7 + j] = unsafe { - *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) - }; - } - } - - let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: {e}").into()) - })?; - let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) - })?; - - let mut witness_buf = gpu_witness.device_buffer; - let mut offset = num_rows_padded / 2; - let mut current_layer_len = n; - - loop { - if current_layer_len <= 1 { - break; - } - - let (next_x, next_y) = hal - .witgen - .shard_ram_ec_tree_layer( - &gpu_cols, - &cur_x, - &cur_y, - &mut witness_buf, - current_layer_len, - offset, - num_rows_padded, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU EC tree layer failed: {e}").into()) - })?; - - current_layer_len /= 2; - offset += current_layer_len; - cur_x = next_x; - cur_y = next_y; - } - - Ok(witness_buf) - })?; + let witness_buf = gpu_witness.device_buffer; // 5. Structural witness: keep device-resident only when cache policy keeps device backing. // In debug mode or cache-none mode, do transpose + D2H. @@ -531,7 +502,7 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( num_records: usize, num_local_writes: usize, ) -> Result>, ZKVMError> { - use ceno_gpu::{Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31, common::transpose::matrix_transpose}; use gkr_iop::gpu::gpu_prover::get_cuda_hal; use witness::{DeviceMatrixLayout, InstancePaddingStrategy, next_pow2_instance_padding}; @@ -546,8 +517,7 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( Err(_) => return Ok(None), }; - let n = next_pow2_instance_padding(num_records); - let num_rows_padded = 2 * n; + let num_rows_padded = next_pow2_instance_padding(num_records); let col_map = extract_shard_ram_column_map(config, num_witin); @@ -577,55 +547,7 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( }) })?; - // GPU: extract EC points from device records - let witness_buf = tracing::info_span!("gpu_shard_ram_ec_tree_from_device", n).in_scope( - || -> Result<_, ZKVMError> { - let col_offsets = col_map.to_flat(); - let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) - })?; - - let (mut cur_x, mut cur_y) = hal - .witgen - .extract_ec_points_from_device(device_records, num_records, n, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU extract_ec_points failed: {e}").into()) - })?; - - let mut witness_buf = gpu_witness.device_buffer; - let mut offset = num_rows_padded / 2; - let mut current_layer_len = n; - - loop { - if current_layer_len <= 1 { - break; - } - - let (next_x, next_y) = hal - .witgen - .shard_ram_ec_tree_layer( - &gpu_cols, - &cur_x, - &cur_y, - &mut witness_buf, - current_layer_len, - offset, - num_rows_padded, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU EC tree layer failed: {e}").into()) - })?; - - current_layer_len /= 2; - offset += current_layer_len; - cur_x = next_x; - cur_y = next_y; - } - - Ok(witness_buf) - }, - )?; + let witness_buf = gpu_witness.device_buffer; // Structural witness: keep device-resident only when cache policy keeps device backing. // In debug mode or cache-none mode, do transpose + D2H. @@ -770,6 +692,220 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( Ok(Some([raw_witin, raw_structural_witin])) } +pub(crate) fn try_gpu_assign_shard_ram_ec_tree_from_device( + config: &ShardRamEcTreeConfig, + num_witin: usize, + num_structural_witin: usize, + device_records: &ceno_gpu::common::buffer::BufferImpl<'static, u32>, + num_records: usize, + num_write_records: usize, +) -> Result>, ZKVMError> { + use ceno_gpu::{Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose}; + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + use witness::{DeviceMatrixLayout, InstancePaddingStrategy, next_pow2_instance_padding}; + + type BB = ::BaseField; + + if std::any::TypeId::of::() != std::any::TypeId::of::() { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(None), + }; + + if num_records == 0 { + return Ok(Some([ + witness::RowMajorMatrix::empty(), + witness::RowMajorMatrix::empty(), + ])); + } + + let n = next_pow2_instance_padding(num_records); + let num_rows_padded = 2 * n; + let col_map = extract_shard_ram_ec_tree_column_map(config, num_witin); + + let (mut gpu_witness, gpu_structural, mut cur_x, mut cur_y) = tracing::info_span!( + "gpu_shard_ram_ec_tree_per_row_from_device", + n = num_records, + num_write_records, + num_rows_padded, + num_witin, + ) + .in_scope(|| { + hal.witgen + .witgen_shard_ram_ec_tree_per_row_from_device( + &col_map, + device_records, + num_records, + num_write_records, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU shard_ram EC tree per-row kernel failed: {e:?}").into(), + ) + }) + })?; + + tracing::info_span!("gpu_shard_ram_ec_tree_layers_from_device", n).in_scope( + || -> Result<(), ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; + + let mut offset = n; + let mut current_layer_len = n; + while current_layer_len > 1 { + let (next_x, next_y) = hal + .witgen + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut gpu_witness.device_buffer, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU EC tree layer failed: {e}").into()) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; + } + + Ok(()) + }, + )?; + + let raw_structural_witin = if crate::instructions::gpu::config::is_debug_compare_enabled() + || !crate::instructions::gpu::config::should_materialize_witness_on_gpu() + { + let struct_data = tracing::info_span!( + "gpu_shard_ram_ec_tree_structural_transpose_d2h_from_device", + rows = gpu_structural.num_rows, + num_structural_witin, + ) + .in_scope(|| -> Result<_, ZKVMError> { + let mut struct_rmm_buf = hal + .witgen + .alloc_elems_on_device(num_rows_padded * num_structural_witin, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for EC tree struct transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + num_rows_padded, + num_structural_witin, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree struct transpose failed: {e}").into(), + ) + })?; + + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H EC tree struct failed: {e}").into()) + })?; + let out: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + Ok(out) + })?; + witness::RowMajorMatrix::new_by_values( + struct_data, + num_structural_witin, + InstancePaddingStrategy::Default, + ) + } else { + let mut rmm = witness::RowMajorMatrix::new( + num_rows_padded, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + rmm.set_device_backing(gpu_structural.device_buffer, DeviceMatrixLayout::ColMajor); + rmm + }; + + let raw_witin = if crate::instructions::gpu::config::is_debug_compare_enabled() + || !crate::instructions::gpu::config::should_materialize_witness_on_gpu() + { + tracing::info_span!( + "gpu_shard_ram_ec_tree_witness_transpose_d2h_from_device", + num_rows_padded, + num_witin, + ) + .in_scope(|| -> Result<_, ZKVMError> { + let mut rmm_buf = hal + .witgen + .alloc_elems_on_device(num_rows_padded * num_witin, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for EC tree witness transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &gpu_witness.device_buffer, + num_rows_padded, + num_witin, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree witness transpose failed: {e}").into(), + ) + })?; + + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H EC tree witness failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + Ok(witness::RowMajorMatrix::new_by_values( + wit_data, + num_witin, + InstancePaddingStrategy::Default, + )) + })? + } else { + let mut rmm = witness::RowMajorMatrix::new( + num_rows_padded, + num_witin, + InstancePaddingStrategy::Default, + ); + rmm.set_device_backing(gpu_witness.device_buffer, DeviceMatrixLayout::ColMajor); + rmm + }; + + Ok(Some([raw_witin, raw_structural_witin])) +} + /// Full GPU pipeline for assign_shared_circuit: device-resident EC merge + partition + assign. /// Returns `Ok(None)` if GPU is unavailable, `Ok(Some((inputs, lk_mlt)))` on /// success — `lk_mlt` carries the y6_lo byte / LTU lookup multiplicity that @@ -784,10 +920,14 @@ pub(crate) fn try_gpu_assign_shared_circuit( &[crate::tables::MemFinalRecord], )], config: &ShardRamConfig, + ec_tree_config: &crate::tables::ShardRamEcTreeConfig, num_witin: usize, num_structural_witin: usize, + ec_tree_num_witin: usize, + ec_tree_num_structural_witin: usize, ) -> Result< Option<( + Vec>, Vec>, gkr_iop::utils::lk_multiplicity::Multiplicity, )>, @@ -795,11 +935,13 @@ pub(crate) fn try_gpu_assign_shared_circuit( > { use crate::{ instructions::gpu::{ - chips::shard_ram::gpu_batch_continuation_ec_on_device, + chips::shard_ram::{ + gpu_batch_continuation_ec_on_device, try_gpu_assign_shard_ram_ec_tree_from_device, + }, dispatch::take_shared_device_buffers, }, structs::{ChipInput, ZKVMWitnesses}, - tables::{ShardRamCircuit, ShardRamRecord, TableCircuit}, + tables::{ShardRamCircuit, ShardRamEcTreeCircuit, ShardRamRecord, TableCircuit}, witness::LkMultiplicity, }; use ceno_gpu::Buffer; @@ -984,6 +1126,17 @@ pub(crate) fn try_gpu_assign_shared_circuit( const IS_TO_WRITE_SET_U32_OFFSET: usize = 10; const POINT_Y6_U32_OFFSET: usize = 25; + let host_data: Vec = if total_records == 0 { + vec![] + } else { + partitioned_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("[GPU full pipeline] partitioned_buf D2H: {e}").into(), + ) + })? + }; + debug_assert_eq!(host_data.len(), total_records * record_u32s); + // 6.5. Derive ShardRam's per-row y6_lo byte / LTU lookup multiplicity // from the partitioned device buffer. Mirrors the per-row CPU push in // `ShardRamCircuit::assign_instance`; the constraint these queries serve @@ -993,12 +1146,6 @@ pub(crate) fn try_gpu_assign_shared_circuit( if total_records == 0 { return Ok(gkr_iop::utils::lk_multiplicity::Multiplicity::default()); } - let host_data: Vec = partitioned_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness( - format!("[GPU full pipeline] partitioned_buf D2H: {e}").into(), - ) - })?; - debug_assert_eq!(host_data.len(), total_records * record_u32s); let prime = ::MODULUS_U64; let lk_multiplicity = LkMultiplicity::default(); host_data.par_chunks_exact(record_u32s).for_each(|rec| { @@ -1021,10 +1168,10 @@ pub(crate) fn try_gpu_assign_shared_circuit( // 7. GPU assign_instances from device buffer. The proof format stores one // chip proof per circuit, so shard RAM must stay in one witness entry. - let circuit_inputs = + let (circuit_inputs, ec_tree_circuit_inputs) = info_span!("shard_ram_assign_from_device", n = total_records).in_scope(|| { if total_records == 0 { - return Ok::>, ZKVMError>(vec![]); + return Ok::<(Vec>, Vec>), ZKVMError>((vec![], vec![])); } let witness = ShardRamCircuit::::try_gpu_assign_instances_from_device( @@ -1040,11 +1187,31 @@ pub(crate) fn try_gpu_assign_shared_circuit( ZKVMError::InvalidWitness("GPU shard_ram from_device returned None".into()) })?; - Ok::<_, ZKVMError>(vec![ChipInput::new( + let num_reads = total_records - num_writes; + let circuit_inputs = vec![ChipInput::new( ShardRamCircuit::::name(), witness, - [num_writes, total_records - num_writes], - )]) + [num_writes, num_reads], + )]; + + let ec_tree_witness = try_gpu_assign_shard_ram_ec_tree_from_device( + ec_tree_config, + ec_tree_num_witin, + ec_tree_num_structural_witin, + &partitioned_buf, + total_records, + num_writes, + )? + .ok_or_else(|| { + ZKVMError::InvalidWitness("GPU shard_ram EC tree from_device returned None".into()) + })?; + let ec_tree_circuit_inputs = vec![ChipInput::new( + ShardRamEcTreeCircuit::::name(), + ec_tree_witness, + [num_reads, num_writes], + )]; + + Ok::<_, ZKVMError>((circuit_inputs, ec_tree_circuit_inputs)) })?; tracing::info!( @@ -1052,7 +1219,7 @@ pub(crate) fn try_gpu_assign_shared_circuit( total_records, ); - Ok(Some((circuit_inputs, lk_mlt))) + Ok(Some((circuit_inputs, ec_tree_circuit_inputs, lk_mlt))) } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index ca782c1f7..934a9d0f9 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -6,8 +6,8 @@ use crate::{ tables::{ DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsInitCircuit, HintsTable, LocalFinalCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, RegTable, - RegTableInitCircuit, ShardRamCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, - StaticMemTable, TableCircuit, + RegTableInitCircuit, ShardRamCircuit, ShardRamEcTreeCircuit, StackInitCircuit, StackTable, + StaticMemInitCircuit, StaticMemTable, TableCircuit, }, }; use ceno_emul::{Addr, IterAddresses, WORD_SIZE, Word}; @@ -30,6 +30,8 @@ pub struct MmuConfig { pub local_final_circuit: as TableCircuit>::TableConfig, /// ram bus to deal with cross shard read/write pub ram_bus_circuit: as TableCircuit>::TableConfig, + /// EC accumulation tree for cross-shard read/write points. + pub ram_bus_ec_tree_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -44,6 +46,7 @@ impl MmuConfig { let heap_init_config = cs.register_table_circuit::>(); let local_final_circuit = cs.register_table_circuit::>(); let ram_bus_circuit = cs.register_table_circuit::>(); + let ram_bus_ec_tree_circuit = cs.register_table_circuit::>(); Self { reg_init_config, @@ -53,6 +56,7 @@ impl MmuConfig { heap_init_config, local_final_circuit, ram_bus_circuit, + ram_bus_ec_tree_circuit, params: cs.params.clone(), } } @@ -205,6 +209,7 @@ impl MmuConfig { cs, &(shard_ctx, all_records.as_slice()), &self.ram_bus_circuit, + &self.ram_bus_ec_tree_circuit, ) })?; Ok(()) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 45fbda851..d440af2cc 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -13,8 +13,8 @@ use crate::{ utils::{ GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, first_layer_output_group_stage_masks, - infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - split_rotation_evals, + first_layer_selector_contexts, infer_tower_logup_witness, infer_tower_product_witness, + interleaving_mles_to_mles, split_rotation_evals, }, verifier::eval_batched_main_frontload_terms, }, @@ -891,7 +891,6 @@ impl> MainSumcheckProver> MainSumcheckProver> let mut max_degree = 0usize; for job in &jobs { - let ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit, - } = job.cs; + let ComposedConstrainSystem { gkr_circuit, .. } = job.cs; - let num_instances = job.input.num_instances(); let log2_num_instances = job.input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + job.cs.rotation_vars().unwrap_or(0); max_num_variables = max_num_variables.max(num_var_with_rotation); @@ -1131,40 +1100,12 @@ impl> }; let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); max_degree = max_degree.max(first_layer.max_expr_degree + 1); - let group_stage_masks = first_layer_output_group_stage_masks(job.cs, gkr_circuit); - let selector_ctxs = first_layer - .out_sel_and_eval_exprs - .iter() - .zip_eq(group_stage_masks.iter()) - .map(|((selector, _), stage_mask)| { - if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() - { - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - } - } else if cs.r_selector.as_ref() == Some(selector) { - SelectorContext { - offset: 0, - num_instances: job.input.num_instances[0], - num_vars: num_var_with_rotation, - } - } else if cs.w_selector.as_ref() == Some(selector) { - SelectorContext { - offset: job.input.num_instances[0], - num_instances: job.input.num_instances[1], - num_vars: num_var_with_rotation, - } - } else { - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - } - } - }) - .collect_vec(); + let selector_ctxs = first_layer_selector_contexts( + job.cs, + gkr_circuit, + job.input.num_instances, + num_var_with_rotation, + ); let mut out_evals = vec![PointAndEval::new(job.rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index cbbc40ca6..75af4737d 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -15,7 +15,7 @@ use crate::{ utils::{ GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, first_layer_output_group_stage_masks, - split_rotation_evals, + first_layer_selector_contexts, split_rotation_evals, }, verifier::eval_batched_main_frontload_terms, }, @@ -513,7 +513,6 @@ pub fn prove_main_constraints_impl< gkr_circuit, } = composed_cs; - let num_instances = input.num_instances(); let log2_num_instances = input.log2_num_instances(); let num_threads = optimal_sumcheck_threads(log2_num_instances); let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); @@ -539,38 +538,12 @@ pub fn prove_main_constraints_impl< } let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, gkr_circuit); - let selector_ctxs = first_layer - .out_sel_and_eval_exprs - .iter() - .zip_eq(group_stage_masks.iter()) - .map(|((selector, _), stage_mask)| { - if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() { - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - } - } else if cs.r_selector.as_ref() == Some(selector) { - SelectorContext { - offset: 0, - num_instances: input.num_instances[0], - num_vars: num_var_with_rotation, - } - } else if cs.w_selector.as_ref() == Some(selector) { - SelectorContext { - offset: input.num_instances[0], - num_instances: input.num_instances[1], - num_vars: num_var_with_rotation, - } - } else { - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - } - } - }) - .collect_vec(); + let selector_ctxs = first_layer_selector_contexts( + composed_cs, + gkr_circuit, + input.num_instances, + num_var_with_rotation, + ); let mut out_evals = vec![PointAndEval::new(rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; @@ -2643,11 +2616,7 @@ impl> let mut max_num_variables = 0usize; for job in &jobs { - let ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit, - } = job.cs; - let num_instances = job.input.num_instances(); + let ComposedConstrainSystem { gkr_circuit, .. } = job.cs; let log2_num_instances = job.input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + job.cs.rotation_vars().unwrap_or(0); max_num_variables = max_num_variables.max(num_var_with_rotation); @@ -2656,40 +2625,12 @@ impl> panic!("empty gkr circuit") }; let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); - let group_stage_masks = first_layer_output_group_stage_masks(job.cs, gkr_circuit); - let selector_ctxs = first_layer - .out_sel_and_eval_exprs - .iter() - .zip_eq(group_stage_masks.iter()) - .map(|((selector, _), stage_mask)| { - if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() - { - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - } - } else if cs.r_selector.as_ref() == Some(selector) { - SelectorContext { - offset: 0, - num_instances: job.input.num_instances[0], - num_vars: num_var_with_rotation, - } - } else if cs.w_selector.as_ref() == Some(selector) { - SelectorContext { - offset: job.input.num_instances[0], - num_instances: job.input.num_instances[1], - num_vars: num_var_with_rotation, - } - } else { - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - } - } - }) - .collect_vec(); + let selector_ctxs = first_layer_selector_contexts( + job.cs, + gkr_circuit, + job.input.num_instances, + num_var_with_rotation, + ); let mut out_evals = vec![PointAndEval::new(job.rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 4921d7f8c..b21d167e2 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -15,6 +15,7 @@ use gkr_iop::{ layer::{LayerWitness, ROTATION_OPENING_COUNT}, }, hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend}, + selector::{SelectorContext, SelectorType}, }; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -161,6 +162,41 @@ pub(crate) fn first_layer_output_group_stage_masks( group_masks } +pub(crate) fn first_layer_selector_contexts( + composed_cs: &ComposedConstrainSystem, + circuit: &GKRCircuit, + num_instances: [usize; 2], + num_vars: usize, +) -> Vec { + let cs = &composed_cs.zkvm_v1_css; + let total_num_instances = num_instances.iter().sum(); + let first_layer = circuit.layers.first().expect("empty gkr circuit layer"); + let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, circuit); + let distinct_rw_selectors = + cs.r_selector.is_some() && cs.w_selector.is_some() && cs.r_selector != cs.w_selector; + + first_layer + .out_sel_and_eval_exprs + .iter() + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if stage_mask.contains(GkrOutputStageMask::TOWER) + && distinct_rw_selectors + && matches!(selector, SelectorType::Prefix(_)) + { + if cs.r_selector.as_ref() == Some(selector) { + return SelectorContext::new(0, num_instances[0], num_vars); + } + if cs.w_selector.as_ref() == Some(selector) { + return SelectorContext::new(num_instances[0], num_instances[1], num_vars); + } + } + + SelectorContext::new(0, total_num_instances, num_vars) + }) + .collect_vec() +} + pub(crate) struct EccBridgeClaims { pub(crate) xy_point: Point, pub(crate) s_point: Point, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 76ed4ccb0..c9c3babe8 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -18,10 +18,7 @@ use crate::{ scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, septic_curve::{SepticExtension, SepticPoint}, - utils::{ - GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, - first_layer_output_group_stage_masks, - }, + utils::{assign_group_evals, derive_ecc_bridge_claims, first_layer_selector_contexts}, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VK_DIGEST_LEN, @@ -1016,27 +1013,12 @@ impl> let first_layer = gkr_circuit.layers.first().ok_or_else(|| { ZKVMError::InvalidProof(format!("{_name} empty gkr circuit layers").into()) })?; - let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, gkr_circuit); - let selector_ctxs = first_layer - .out_sel_and_eval_exprs - .iter() - .zip_eq(group_stage_masks.iter()) - .map(|((selector, _), stage_mask)| { - if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() { - SelectorContext::new(0, num_instances, num_var_with_rotation) - } else if cs.r_selector.as_ref() == Some(selector) { - SelectorContext::new(0, proof.num_instances[0], num_var_with_rotation) - } else if cs.w_selector.as_ref() == Some(selector) { - SelectorContext::new( - proof.num_instances[0], - proof.num_instances[1], - num_var_with_rotation, - ) - } else { - SelectorContext::new(0, num_instances, num_var_with_rotation) - } - }) - .collect_vec(); + let selector_ctxs = first_layer_selector_contexts( + composed_cs, + gkr_circuit, + proof.num_instances, + num_var_with_rotation, + ); let mut out_evals = vec![PointAndEval::new(rt_main.clone(), E::ZERO); gkr_circuit.n_evaluations]; diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 0bd430063..462226ded 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -5,8 +5,8 @@ use crate::{ instructions::Instruction, scheme::septic_curve::SepticPoint, tables::{ - ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, - TableCircuit, + ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamEcTreeCircuit, + ShardRamInput, ShardRamRecord, TableCircuit, }, witness::LkMultiplicity, }; @@ -111,6 +111,7 @@ pub type RAMType = gkr_iop::RAMType; #[repr(u16)] pub enum CustomRWTag { KeccakState = 0, + ShardRamEcPoint = 1, } impl CustomRWTag { @@ -569,6 +570,7 @@ impl ZKVMWitnesses { &[(&'static str, Option>, &[MemFinalRecord])], ), config: & as TableCircuit>::TableConfig, + ec_tree_config: & as TableCircuit>::TableConfig, ) -> Result<(), ZKVMError> { use tracing::info_span; @@ -576,7 +578,13 @@ impl ZKVMWitnesses { // Only when GPU witgen is enabled (otherwise witgen must not touch GPU). #[cfg(feature = "gpu")] if crate::instructions::gpu::config::is_gpu_witgen_enabled() { - let gpu_result = self.try_assign_shared_circuit_gpu(cs, shard_ctx, final_mem, config); + let gpu_result = self.try_assign_shared_circuit_gpu( + cs, + shard_ctx, + final_mem, + config, + ec_tree_config, + ); match gpu_result { Ok(true) => return Ok(()), Ok(false) => {} /* GPU pipeline unavailable (no shared buffers), fall through to CPU */ @@ -709,24 +717,30 @@ impl ZKVMWitnesses { } assert!(self.combined_lk_mlt.is_none()); - let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); + let shard_ram_cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); + let ec_tree_cs = cs + .get_cs(&ShardRamEcTreeCircuit::::name()) + .expect("ShardRamEcTreeCircuit must be registered"); let n_global = global_input.len(); // `ShardRamCircuit::assign_instances` ignores the `multiplicity` // argument (its lookup contribution is derived externally above), so // an empty slice is sufficient here and matches the pre-finalize // ordering: `combined_lk_mlt` is intentionally `None` at this point. let lk_multiplicity = LkMultiplicity::default(); - let circuit_inputs = + let (circuit_inputs, ec_tree_circuit_inputs) = info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| { if global_input.is_empty() { - return Ok::>, ZKVMError>(vec![]); + return Ok::<(Vec>, Vec>), ZKVMError>(( + vec![], + vec![], + )); } let mut lk_multiplicity = lk_multiplicity.clone(); let witness = ShardRamCircuit::assign_instances_with_lk_multiplicities( config, - cs.zkvm_v1_css.num_witin as usize, - cs.zkvm_v1_css.num_structural_witin as usize, + shard_ram_cs.zkvm_v1_css.num_witin as usize, + shard_ram_cs.zkvm_v1_css.num_structural_witin as usize, &mut lk_multiplicity, &global_input, )?; @@ -736,11 +750,37 @@ impl ZKVMWitnesses { .count(); let num_writes = global_input.len() - num_reads; - Ok(vec![ChipInput::new( - ShardRamCircuit::::name(), - witness, - [num_reads, num_writes], - )]) + let ec_tree_input = global_input + .iter() + .filter(|access| !access.record.is_to_write_set) + .chain( + global_input + .iter() + .filter(|access| access.record.is_to_write_set), + ) + .cloned() + .collect_vec(); + let ec_tree_witness = + ShardRamEcTreeCircuit::assign_instances_with_lk_multiplicities( + ec_tree_config, + ec_tree_cs.zkvm_v1_css.num_witin as usize, + ec_tree_cs.zkvm_v1_css.num_structural_witin as usize, + &mut LkMultiplicity::default(), + &ec_tree_input, + )?; + + Ok(( + vec![ChipInput::new( + ShardRamCircuit::::name(), + witness, + [num_reads, num_writes], + )], + vec![ChipInput::new( + ShardRamEcTreeCircuit::::name(), + ec_tree_witness, + [num_writes, num_reads], + )], + )) })?; assert!( @@ -757,6 +797,11 @@ impl ZKVMWitnesses { .insert(ShardRamCircuit::::name(), circuit_inputs) .is_none() ); + assert!( + self.witnesses + .insert(ShardRamEcTreeCircuit::::name(), ec_tree_circuit_inputs) + .is_none() + ); Ok(()) } @@ -776,25 +821,35 @@ impl ZKVMWitnesses { shard_ctx: &ShardContext, final_mem: &[(&'static str, Option>, &[MemFinalRecord])], config: & as TableCircuit>::TableConfig, + ec_tree_config: & as TableCircuit>::TableConfig, ) -> Result { assert!(self.combined_lk_mlt.is_none()); - let cs_inner = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); - let num_witin = cs_inner.zkvm_v1_css.num_witin as usize; - let num_structural_witin = cs_inner.zkvm_v1_css.num_structural_witin as usize; + let shard_ram_cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); + let ec_tree_cs = cs + .get_cs(&ShardRamEcTreeCircuit::::name()) + .expect("ShardRamEcTreeCircuit must be registered"); match crate::instructions::gpu::chips::shard_ram::try_gpu_assign_shared_circuit::( shard_ctx, final_mem, config, - num_witin, - num_structural_witin, + ec_tree_config, + shard_ram_cs.zkvm_v1_css.num_witin as usize, + shard_ram_cs.zkvm_v1_css.num_structural_witin as usize, + ec_tree_cs.zkvm_v1_css.num_witin as usize, + ec_tree_cs.zkvm_v1_css.num_structural_witin as usize, )? { - Some((circuit_inputs, lk_mlt)) => { + Some((circuit_inputs, ec_tree_circuit_inputs, lk_mlt)) => { assert!( self.witnesses .insert(ShardRamCircuit::::name(), circuit_inputs) .is_none() ); + assert!( + self.witnesses + .insert(ShardRamEcTreeCircuit::::name(), ec_tree_circuit_inputs) + .is_none() + ); assert!( self.lk_mlts .insert(ShardRamCircuit::::name(), lk_mlt) diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index de033aefd..7e05bc5db 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -8,7 +8,7 @@ use crate::{ gadgets::Poseidon2Config, instructions::riscv::constants::UINT_LIMBS, scheme::septic_curve::{SepticExtension, SepticPoint}, - structs::{ProgramParams, RAMType}, + structs::{CustomRWTag, ProgramParams, RAMType}, tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, }; @@ -30,8 +30,8 @@ use p3::{ }; use rayon::{ iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend, - ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelExtend, ParallelIterator, }, prelude::ParallelSliceMut, slice::ParallelSlice, @@ -43,6 +43,14 @@ use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTE pub(crate) const Y6_LO_TOP_BYTE_LT_BOUND: u64 = 60; +fn shard_ram_ec_point_record(x: &[WitIn], y: &[WitIn]) -> Vec> { + [CustomRWTag::ShardRamEcPoint.expr::()] + .into_iter() + .chain(x.iter().map(|w| w.expr())) + .chain(y.iter().map(|w| w.expr())) + .collect() +} + /// A record for a read/write into the shard RAM #[derive(Debug, Clone)] pub struct ShardRamRecord { @@ -187,7 +195,6 @@ pub struct ShardRamConfig { pub(crate) is_global_write: WitIn, pub(crate) x: Vec, pub(crate) y: Vec, - pub(crate) slope: Vec, // Byte limbs of `y6_lo`, the helper that binds `y[SEPTIC_EXTENSION_DEGREE - 1]` // to `is_global_write` in `configure`. pub(crate) y6_lo_bytes: [WitIn; 4], @@ -203,9 +210,6 @@ impl ShardRamConfig { let y: Vec = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| cb.create_witin(|| format!("y{}", i))) .collect(); - let slope: Vec = (0..SEPTIC_EXTENSION_DEGREE) - .map(|i| cb.create_witin(|| format!("slope{}", i))) - .collect(); let addr = cb.create_witin(|| "addr"); let is_ram_register = cb.create_witin(|| "is_ram_register"); let value = UInt::new_unchecked(|| "value", cb)?; @@ -261,19 +265,22 @@ impl ShardRamConfig { )?; cb.write_rlc_record( || "w_record", - ram_type, + ram_type.clone(), record.clone(), cb.rlc_chip_record(record), )?; - // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol - let final_sum = cb.query_global_rw_sum()?; - cb.ec_sum( - x.iter().map(|xi| xi.expr()).collect::>(), - y.iter().map(|yi| yi.expr()).collect::>(), - slope.iter().map(|si| si.expr()).collect::>(), - final_sum.into_iter().map(|x| x.expr()).collect::>(), - ); + let ec_point_record = shard_ram_ec_point_record(&x, &y); + cb.read_record( + || "shard_ram_ec_point_in", + RAMType::Custom, + ec_point_record.clone(), + )?; + cb.write_record( + || "shard_ram_ec_point_out", + RAMType::Custom, + ec_point_record, + )?; // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) @@ -327,7 +334,6 @@ impl ShardRamConfig { Ok(ShardRamConfig { x, y, - slope, addr, is_ram_register, value, @@ -342,6 +348,54 @@ impl ShardRamConfig { } } +pub struct ShardRamEcTreeConfig { + pub(crate) x: Vec, + pub(crate) y: Vec, + pub(crate) slope: Vec, + _marker: PhantomData, +} + +impl ShardRamEcTreeConfig { + pub fn configure(cb: &mut CircuitBuilder) -> Result { + let x: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("x{i}"))) + .collect(); + let y: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("y{i}"))) + .collect(); + let slope: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("slope{i}"))) + .collect(); + + let ec_point_record = shard_ram_ec_point_record(&x, &y); + cb.read_record( + || "shard_ram_ec_point_in", + RAMType::Custom, + ec_point_record.clone(), + )?; + cb.write_record( + || "shard_ram_ec_point_out", + RAMType::Custom, + ec_point_record, + )?; + + let final_sum = cb.query_global_rw_sum()?; + cb.ec_sum( + x.iter().map(|xi| xi.expr()).collect::>(), + y.iter().map(|yi| yi.expr()).collect::>(), + slope.iter().map(|si| si.expr()).collect::>(), + final_sum.into_iter().map(|x| x.expr()).collect::>(), + ); + + Ok(Self { + x, + y, + slope, + _marker: PhantomData, + }) + } +} + /// This chip is used to manage read/write into a global set /// shared among multiple shards #[derive(Default)] @@ -356,6 +410,11 @@ pub struct ShardRamInput { pub ec_point: ECPoint, } +#[derive(Default)] +pub struct ShardRamEcTreeCircuit { + _marker: PhantomData, +} + /// Decode `y6_lo` (the byte-decomposed helper bound to `is_global_write` in /// `ShardRamConfig::configure`) from a witnessed `y6` field element. Mirrors /// the prover-side derivation done inside the per-row witness assignment; @@ -441,10 +500,10 @@ impl ShardRamCircuit { input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); - config - .perm_config - // TODO: remove hardcoded constant 28 - .assign_instance(&mut instance[28 + UINT_LIMBS..], input); + config.perm_config.assign_instance( + &mut instance[config.perm_config.p3_cols[0].id as usize..], + input, + ); Ok(()) } @@ -496,11 +555,6 @@ impl TableCircuit for ShardRamCircuit { let selector_r = cb.create_placeholder_structural_witin(|| "selector_r"); let selector_w = cb.create_placeholder_structural_witin(|| "selector_w"); let selector_zero = cb.create_placeholder_structural_witin(|| "selector_zero"); - let selector_ecc_x = cb.create_placeholder_structural_witin(|| "selector_ecc_x"); - let selector_ecc_y = cb.create_placeholder_structural_witin(|| "selector_ecc_y"); - let selector_ecc_s = cb.create_placeholder_structural_witin(|| "selector_ecc_s"); - let selector_ecc_x3 = cb.create_placeholder_structural_witin(|| "selector_ecc_x3"); - let selector_ecc_y3 = cb.create_placeholder_structural_witin(|| "selector_ecc_y3"); let config = Self::construct_circuit(cb, param)?; @@ -522,14 +576,6 @@ impl TableCircuit for ShardRamCircuit { cb.cs.w_selector = Some(selector_w); cb.cs.zero_selector = Some(selector_zero.clone()); cb.cs.lk_selector = Some(selector_zero); - cb.cs.ec_bridge_selectors = Some([ - SelectorType::Whole(selector_ecc_x.expr()), - SelectorType::Whole(selector_ecc_y.expr()), - SelectorType::Whole(selector_ecc_s.expr()), - SelectorType::Whole(selector_ecc_x3.expr()), - SelectorType::Whole(selector_ecc_y3.expr()), - ]); - // all shared the same selector let (out_evals, mut chip) = ( [ @@ -575,7 +621,7 @@ impl TableCircuit for ShardRamCircuit { } #[cfg(feature = "gpu")] - { + if crate::instructions::gpu::config::is_gpu_witgen_enabled() { if let Some(result) = Self::try_gpu_assign_instances( config, num_witin, @@ -590,20 +636,13 @@ impl TableCircuit for ShardRamCircuit { // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` - // ShardRam expects exactly these structural selectors: - // r, w, zero, ecc_x, ecc_y, ecc_s, ecc_x3, ecc_y3. assert_eq!( - num_structural_witin, 8, - "ShardRam requires exactly 8 structural selectors (r,w,zero,ecc_x,ecc_y,ecc_s,ecc_x3,ecc_y3)" + num_structural_witin, 3, + "ShardRam leaf requires r, w, and zero structural selectors" ); let selector_r_witin = WitIn { id: 0 }; let selector_w_witin = WitIn { id: 1 }; let selector_zero_witin = WitIn { id: 2 }; - let selector_ecc_x_witin = WitIn { id: 3 }; - let selector_ecc_y_witin = WitIn { id: 4 }; - let selector_ecc_s_witin = WitIn { id: 5 }; - let selector_ecc_x3_witin = WitIn { id: 6 }; - let selector_ecc_y3_witin = WitIn { id: 7 }; let nthreads = max_usable_threads(); @@ -626,10 +665,7 @@ impl TableCircuit for ShardRamCircuit { .max(1); let n = next_pow2_instance_padding(steps.len()); - // compute the input for the binary tree for ec point summation - - // *2 because we need to store the internal nodes of binary tree for ec point summation - let num_rows_padded = 2 * n; + let num_rows_padded = n; let mut raw_witin = { let matrix_size = num_rows_padded * num_witin; @@ -651,17 +687,6 @@ impl TableCircuit for ShardRamCircuit { ); RowMajorMatrix::new(value, num_structural_witin) }; - // ECC bridge selectors are `Whole`, so keep them active on all rows. - raw_structual_witin - .values - .par_chunks_mut(num_structural_witin) - .for_each(|row| { - set_val!(row, selector_ecc_x_witin, E::BaseField::ONE); - set_val!(row, selector_ecc_y_witin, E::BaseField::ONE); - set_val!(row, selector_ecc_s_witin, E::BaseField::ONE); - set_val!(row, selector_ecc_x3_witin, E::BaseField::ONE); - set_val!(row, selector_ecc_y3_witin, E::BaseField::ONE); - }); let raw_witin_iter = raw_witin.values[0..steps.len() * num_witin] .par_chunks_mut(num_instance_per_batch * num_witin); let raw_structual_witin_iter = raw_structual_witin.values @@ -695,6 +720,205 @@ impl TableCircuit for ShardRamCircuit { }) .collect::>()?; + let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_witin, + InstancePaddingStrategy::Default, + ); + let raw_structual_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_structual_witin, + InstancePaddingStrategy::Default, + ); + Ok([raw_witin, raw_structual_witin]) + } +} + +impl ShardRamEcTreeCircuit { + fn assign_leaf_instance( + config: &ShardRamEcTreeConfig, + instance: &mut [E::BaseField], + input: &ShardRamInput, + ) { + config + .x + .iter() + .chain(config.y.iter()) + .zip_eq( + input + .ec_point + .point + .x + .deref() + .iter() + .chain(input.ec_point.point.y.deref().iter()), + ) + .for_each(|(witin, fe)| { + set_val!(instance, *witin, *fe); + }); + } + + pub fn extract_ec_sum( + config: &ShardRamEcTreeConfig, + rmm: &witness::RowMajorMatrix<::BaseField>, + ) -> SepticPoint<::BaseField> { + assert!(rmm.height() >= 2); + let instance = &rmm[rmm.height() - 2]; + + let xy = config + .x + .iter() + .chain(config.y.iter()) + .map(|witin| instance[witin.id as usize]) + .collect_vec(); + + let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); + let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); + + SepticPoint::from_affine(x, y) + } +} + +impl TableCircuit for ShardRamEcTreeCircuit { + type TableConfig = ShardRamEcTreeConfig; + type FixedInput = (); + type WitnessInput<'a> = [ShardRamInput]; + + fn name() -> String { + "ShardRamEcTreeCircuit".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result { + Ok(ShardRamEcTreeConfig::configure(cb)?) + } + + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let selector_r = cb.create_placeholder_structural_witin(|| "selector_r"); + let selector_w = cb.create_placeholder_structural_witin(|| "selector_w"); + let selector_ecc_x = cb.create_placeholder_structural_witin(|| "selector_ecc_x"); + let selector_ecc_y = cb.create_placeholder_structural_witin(|| "selector_ecc_y"); + let selector_ecc_s = cb.create_placeholder_structural_witin(|| "selector_ecc_s"); + let selector_ecc_x3 = cb.create_placeholder_structural_witin(|| "selector_ecc_x3"); + let selector_ecc_y3 = cb.create_placeholder_structural_witin(|| "selector_ecc_y3"); + + let config = Self::construct_circuit(cb, param)?; + + let w_len = cb.cs.w_expressions.len(); + let r_len = cb.cs.r_expressions.len(); + let lk_len = cb.cs.lk_expressions.len(); + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); + + cb.cs.r_selector = Some(SelectorType::Prefix(selector_r.expr())); + cb.cs.w_selector = Some(SelectorType::Prefix(selector_w.expr())); + cb.cs.ec_bridge_selectors = Some([ + SelectorType::Whole(selector_ecc_x.expr()), + SelectorType::Whole(selector_ecc_y.expr()), + SelectorType::Whole(selector_ecc_s.expr()), + SelectorType::Whole(selector_ecc_x3.expr()), + SelectorType::Whole(selector_ecc_y3.expr()), + ]); + + let (out_evals, mut chip) = ( + [ + (0..r_len).collect_vec(), + (r_len..r_len + w_len).collect_vec(), + (r_len + w_len..r_len + w_len + lk_len).collect_vec(), + (0..zero_len).collect_vec(), + ], + Chip::new_from_cb(cb), + ); + + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _input: &Self::FixedInput, + ) -> witness::RowMajorMatrix<::BaseField> { + unimplemented!() + } + + fn assign_instances_with_lk_multiplicities( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _lk_multiplicity: &mut LkMultiplicity, + steps: &Self::WitnessInput<'_>, + ) -> Result, ZKVMError> { + if steps.is_empty() { + return Ok([ + witness::RowMajorMatrix::empty(), + witness::RowMajorMatrix::empty(), + ]); + } + + assert_eq!( + num_structural_witin, 7, + "ShardRam EC tree requires r, w, and 5 EC bridge selectors" + ); + let selector_r_witin = WitIn { id: 0 }; + let selector_w_witin = WitIn { id: 1 }; + + let n = next_pow2_instance_padding(steps.len()); + let num_rows_padded = 2 * n; + + let mut raw_witin = { + let matrix_size = num_rows_padded * num_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_witin) + }; + let mut raw_structual_witin = { + let matrix_size = num_rows_padded * num_structural_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_structural_witin) + }; + + raw_structual_witin + .values + .par_chunks_mut(num_structural_witin) + .for_each(|row| { + row[2..7].fill(E::BaseField::ONE); + }); + + let num_custom_reads = steps + .iter() + .take_while(|step| !step.record.is_to_write_set) + .count(); + raw_structual_witin.values[0..steps.len() * num_structural_witin] + .par_chunks_mut(num_structural_witin) + .enumerate() + .for_each(|(row_idx, row)| { + if row_idx < num_custom_reads { + set_val!(row, selector_r_witin, E::BaseField::ONE); + } else { + set_val!(row, selector_w_witin, E::BaseField::ONE); + } + }); + + raw_witin.values[0..steps.len() * num_witin] + .par_chunks_mut(num_witin) + .zip_eq(steps.par_iter()) + .for_each(|(instance, step)| Self::assign_leaf_instance(config, instance, step)); + // allocate num_rows_padded size, fill points on first half let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded) .into_par_iter() @@ -808,9 +1032,13 @@ impl ShardRamCircuit { mod tests { use either::Either; use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField}; + use gkr_iop::cpu::{CpuBackend, CpuProver}; use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; - use p3::babybear::BabyBear; + use p3::{ + babybear::BabyBear, + field::{FieldAlgebra, PrimeField32}, + }; use rand::thread_rng; use std::sync::Arc; use tracing_forest::{ForestLayer, util::LevelFilter}; @@ -821,26 +1049,326 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ - PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, - hal::ProofInput, mock_prover::MockProver, prover::ZKVMProver, + PublicValues, + constants::SEPTIC_EXTENSION_DEGREE, + create_backend, create_prover, + hal::ProofInput, + mock_prover::MockProver, + prover::ZKVMProver, septic_curve::SepticPoint, + utils::{WitnessBuildStage, build_main_witness, first_layer_selector_contexts}, }, structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, - tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, + tables::{ + RMMCollections, ShardRamCircuit, ShardRamEcTreeCircuit, ShardRamInput, ShardRamRecord, + TableCircuit, + }, witness::LkMultiplicity, }; #[cfg(feature = "gpu")] - use gkr_iop::{ - gpu::{MultilinearExtensionGpu, get_cuda_hal}, - hal::MultilinearPolynomial, - }; - use p3::field::PrimeField32; + use gkr_iop::gpu::{MultilinearExtensionGpu, get_cuda_hal}; type E = BabyBearExt4; type F = BabyBear; type Perm = ::P; type Pcs = BasefoldDefault; + fn shard_ram_test_inputs(read_count: usize, write_count: usize) -> Vec> { + let perm = ::get_default_perm(); + let reads = (0..read_count).map(|i| ShardRamRecord { + addr: (0x1000 + i * 4) as u32, + ram_type: RAMType::Memory, + value: (0x2000 + i) as u32, + shard: 1, + local_clk: i as u64 + 1, + global_clk: i as u64 + 10, + is_to_write_set: true, + }); + let writes = (0..write_count).map(|i| ShardRamRecord { + addr: (0x2000 + i * 4) as u32, + ram_type: RAMType::Memory, + value: (0x3000 + i) as u32, + shard: 2, + local_clk: 0, + global_clk: i as u64 + 20, + is_to_write_set: false, + }); + + reads + .chain(writes) + .map(|record| { + let ec_point = record.to_ec_point::(&perm); + ShardRamInput { + name: "selector_test", + record, + ec_point, + } + }) + .collect_vec() + } + + fn assert_selector_column( + witness: &witness::RowMajorMatrix, + col: usize, + ones: std::ops::Range, + ) { + for row in 0..witness.height() { + let expected = if ones.contains(&row) { F::ONE } else { F::ZERO }; + assert_eq!(witness[row][col], expected, "selector col {col} row {row}"); + } + } + + fn assert_column_is_binary(witness: &witness::RowMajorMatrix, col: usize) { + for row in 0..witness.height() { + let value = witness[row][col]; + assert!( + value == F::ZERO || value == F::ONE, + "selector col {col} row {row} is not binary: {value}" + ); + } + } + + fn proof_input_for_witness<'a>( + cs: &ConstraintSystem, + witness: &'a RMMCollections, + num_instances: [usize; 2], + has_ecc_ops: bool, + public_value: &PublicValues, + ) -> ProofInput<'a, CpuBackend> { + let witness_mles = witness[0].to_mles().into_iter().map(Arc::new).collect_vec(); + let structural_mles = witness[1].to_mles().into_iter().map(Arc::new).collect_vec(); + let pub_io_evals = cs + .instance + .iter() + .map(|instance| Either::Right(E::from(public_value.query_by_index::(instance.0)))) + .collect_vec(); + + ProofInput { + witness: witness_mles, + structural_witness: structural_mles, + fixed: vec![], + pi: pub_io_evals, + num_instances, + has_ecc_ops, + } + } + + fn assert_inactive_rows_are_one( + records: &[Arc>], + range: std::ops::Range, + inactive_rows: impl IntoIterator, + ) { + let inactive_rows = inactive_rows.into_iter().collect_vec(); + for record_idx in range { + let evals = records[record_idx].get_ext_field_vec(); + for &row in &inactive_rows { + assert_eq!(evals[row], E::ONE, "record {record_idx} row {row}"); + } + } + } + + fn assert_record_rows_match( + left: &Arc>, + left_rows: std::ops::Range, + right: &Arc>, + right_rows: std::ops::Range, + label: &str, + ) { + assert_eq!(left_rows.len(), right_rows.len(), "{label} row count"); + let left_evals = left.get_ext_field_vec(); + let right_evals = right.get_ext_field_vec(); + for (left_row, right_row) in left_rows.zip(right_rows) { + assert_eq!( + left_evals[left_row], right_evals[right_row], + "{label}: left row {left_row}, right row {right_row}" + ); + } + } + + #[test] + fn test_shard_ram_split_selectors_and_tower_padding() { + let read_count = 2; + let write_count = 3; + let input = shard_ram_test_inputs(read_count, write_count); + let ec_tree_input = input + .iter() + .filter(|access| !access.record.is_to_write_set) + .chain(input.iter().filter(|access| access.record.is_to_write_set)) + .cloned() + .collect_vec(); + + let global_ec_sum: SepticPoint = input + .iter() + .map(|record| record.ec_point.point.clone()) + .sum(); + let mut shard_rw_sum = [0u32; SEPTIC_EXTENSION_DEGREE * 2]; + for (i, fe) in global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .enumerate() + { + shard_rw_sum[i] = fe.as_canonical_u32(); + } + let public_value = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, [0; 8], shard_rw_sum); + + let mut leaf_cs = ConstraintSystem::new(|| "shard ram selector leaf"); + let mut leaf_cb = CircuitBuilder::new(&mut leaf_cs); + let (leaf_config, leaf_gkr_circuit) = + ShardRamCircuit::::build_gkr_iop_circuit(&mut leaf_cb, &ProgramParams::default()) + .unwrap(); + let leaf_witness = ShardRamCircuit::::assign_instances_with_lk_multiplicities( + &leaf_config, + leaf_cs.num_witin as usize, + leaf_cs.num_structural_witin as usize, + &mut LkMultiplicity::default(), + &input, + ) + .unwrap(); + + assert_selector_column(&leaf_witness[1], 0, 0..read_count); + assert_selector_column(&leaf_witness[1], 1, read_count..read_count + write_count); + assert_selector_column(&leaf_witness[1], 2, 0..read_count + write_count); + for col in 0..leaf_witness[1].width() { + assert_column_is_binary(&leaf_witness[1], col); + } + + let leaf_composed = ComposedConstrainSystem { + zkvm_v1_css: leaf_cs, + gkr_circuit: leaf_gkr_circuit, + }; + let leaf_gkr = leaf_composed.gkr_circuit.as_ref().unwrap(); + let leaf_selector_ctxs = + first_layer_selector_contexts(&leaf_composed, leaf_gkr, [read_count, write_count], 3); + assert_eq!(leaf_selector_ctxs[0].offset, 0); + assert_eq!(leaf_selector_ctxs[0].num_instances, read_count); + assert_eq!(leaf_selector_ctxs[1].offset, read_count); + assert_eq!(leaf_selector_ctxs[1].num_instances, write_count); + + let leaf_proof_input = proof_input_for_witness( + &leaf_composed.zkvm_v1_css, + &leaf_witness, + [read_count, write_count], + false, + &public_value, + ); + let leaf_records = + build_main_witness::, CpuProver>>( + &leaf_composed, + &leaf_proof_input, + &[E::ONE, E::from_canonical_u32(7)], + WitnessBuildStage::Tower, + ); + let leaf_r_len = leaf_composed.zkvm_v1_css.r_expressions.len() + + leaf_composed.zkvm_v1_css.r_table_expressions.len(); + let leaf_w_len = leaf_composed.zkvm_v1_css.w_expressions.len() + + leaf_composed.zkvm_v1_css.w_table_expressions.len(); + assert_inactive_rows_are_one( + &leaf_records, + 0..leaf_r_len, + read_count..leaf_witness[0].height(), + ); + assert_inactive_rows_are_one( + &leaf_records, + leaf_r_len..leaf_r_len + leaf_w_len, + (0..read_count).chain(read_count + write_count..leaf_witness[0].height()), + ); + + let mut ec_tree_cs = ConstraintSystem::new(|| "shard ram selector ec tree"); + let mut ec_tree_cb = CircuitBuilder::new(&mut ec_tree_cs); + let (ec_tree_config, ec_tree_gkr_circuit) = + ShardRamEcTreeCircuit::::build_gkr_iop_circuit( + &mut ec_tree_cb, + &ProgramParams::default(), + ) + .unwrap(); + let ec_tree_witness = ShardRamEcTreeCircuit::::assign_instances_with_lk_multiplicities( + &ec_tree_config, + ec_tree_cs.num_witin as usize, + ec_tree_cs.num_structural_witin as usize, + &mut LkMultiplicity::default(), + &ec_tree_input, + ) + .unwrap(); + + assert_selector_column(&ec_tree_witness[1], 0, 0..write_count); + assert_selector_column( + &ec_tree_witness[1], + 1, + write_count..write_count + read_count, + ); + for col in 0..ec_tree_witness[1].width() { + assert_column_is_binary(&ec_tree_witness[1], col); + } + for col in 2..ec_tree_witness[1].width() { + assert_selector_column(&ec_tree_witness[1], col, 0..ec_tree_witness[1].height()); + } + + let ec_tree_composed = ComposedConstrainSystem { + zkvm_v1_css: ec_tree_cs, + gkr_circuit: ec_tree_gkr_circuit, + }; + let ec_tree_gkr = ec_tree_composed.gkr_circuit.as_ref().unwrap(); + let ec_tree_selector_ctxs = first_layer_selector_contexts( + &ec_tree_composed, + ec_tree_gkr, + [write_count, read_count], + 4, + ); + assert_eq!(ec_tree_selector_ctxs[0].offset, 0); + assert_eq!(ec_tree_selector_ctxs[0].num_instances, write_count); + assert_eq!(ec_tree_selector_ctxs[1].offset, write_count); + assert_eq!(ec_tree_selector_ctxs[1].num_instances, read_count); + + let ec_tree_proof_input = proof_input_for_witness( + &ec_tree_composed.zkvm_v1_css, + &ec_tree_witness, + [write_count, read_count], + true, + &public_value, + ); + let ec_tree_records = + build_main_witness::, CpuProver>>( + &ec_tree_composed, + &ec_tree_proof_input, + &[E::ONE, E::from_canonical_u32(7)], + WitnessBuildStage::Tower, + ); + let ec_tree_r_len = ec_tree_composed.zkvm_v1_css.r_expressions.len() + + ec_tree_composed.zkvm_v1_css.r_table_expressions.len(); + let ec_tree_w_len = ec_tree_composed.zkvm_v1_css.w_expressions.len() + + ec_tree_composed.zkvm_v1_css.w_table_expressions.len(); + assert_inactive_rows_are_one( + &ec_tree_records, + 0..ec_tree_r_len, + write_count..ec_tree_witness[0].height(), + ); + assert_inactive_rows_are_one( + &ec_tree_records, + ec_tree_r_len..ec_tree_r_len + ec_tree_w_len, + (0..write_count).chain(write_count + read_count..ec_tree_witness[0].height()), + ); + + let leaf_custom_read = &leaf_records[leaf_r_len - 1]; + let leaf_custom_write = &leaf_records[leaf_r_len + leaf_w_len - 1]; + let ec_tree_custom_read = &ec_tree_records[ec_tree_r_len - 1]; + let ec_tree_custom_write = &ec_tree_records[ec_tree_r_len + ec_tree_w_len - 1]; + assert_record_rows_match( + leaf_custom_read, + 0..read_count, + ec_tree_custom_write, + write_count..write_count + read_count, + "leaf read vs ec-tree write", + ); + assert_record_rows_match( + leaf_custom_write, + read_count..read_count + write_count, + ec_tree_custom_read, + 0..write_count, + "leaf write vs ec-tree read", + ); + } + #[test] fn test_shard_ram_circuit() { // default filter @@ -940,10 +1468,39 @@ mod tests { ) .unwrap(); - // api extract ec sum from rmm witness + let mut ec_tree_cs = ConstraintSystem::new(|| "global ec tree chip test"); + let mut ec_tree_cb = CircuitBuilder::new(&mut ec_tree_cs); + let (ec_tree_config, _ec_tree_gkr_circuit) = ShardRamEcTreeCircuit::build_gkr_iop_circuit( + &mut ec_tree_cb, + &ProgramParams::default(), + ) + .unwrap(); + let ec_tree_input = input + .iter() + .filter(|access| !access.record.is_to_write_set) + .chain(input.iter().filter(|access| access.record.is_to_write_set)) + .cloned() + .collect_vec(); + let ec_tree_witness = ShardRamEcTreeCircuit::assign_instances_with_lk_multiplicities( + &ec_tree_config, + ec_tree_cb.cs.num_witin as usize, + ec_tree_cb.cs.num_structural_witin as usize, + &mut LkMultiplicity::default(), + &ec_tree_input, + ) + .unwrap(); + + // EC accumulation lives in the split EC tree chip. assert_eq!( global_ec_sum, - ShardRamCircuit::extract_ec_sum(&config, &witness[0]) + ShardRamEcTreeCircuit::extract_ec_sum(&ec_tree_config, &ec_tree_witness[0]) + ); + MockProver::::assert_satisfied_raw( + &ec_tree_cb, + ec_tree_witness.clone(), + &[], + Some([E::random(&mut thread_rng()), E::random(&mut thread_rng())]), + None, ); let composed_cs = ComposedConstrainSystem { @@ -1000,7 +1557,7 @@ mod tests { fixed: vec![], pi: pub_io_evals, num_instances: [n_global_writes as usize, n_global_reads as usize], - has_ecc_ops: true, + has_ecc_ops: false, }; let mut rng = thread_rng(); let challenges = [E::random(&mut rng), E::random(&mut rng)];