diff --git a/ceno_recursion_v2/src/bus.rs b/ceno_recursion_v2/src/bus.rs index 21562d770..c65e7adf3 100644 --- a/ceno_recursion_v2/src/bus.rs +++ b/ceno_recursion_v2/src/bus.rs @@ -1,4 +1,3 @@ -use openvm_poseidon2_air::POSEIDON2_WIDTH; use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; use recursion_circuit::{ bus as upstream, define_typed_per_proof_lookup_bus, define_typed_per_proof_permutation_bus, @@ -19,7 +18,7 @@ pub use upstream::{ #[repr(C)] #[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] pub struct ForkedTranscriptBusMessage { - /// Fork identifier (1-based). Matches TranscriptAir's fork_id column. + /// Fork identifier (0-based). Matches TranscriptAir's fork_id column. pub fork_id: T, /// Position within the fork transcript namespace. pub tidx: T, @@ -57,17 +56,31 @@ define_typed_per_proof_lookup_bus!(LookupChallengeBus, LookupChallengeMessage); #[repr(C)] #[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] pub struct TowerModuleMessage { - pub idx: T, - pub tidx: T, - pub n_logup: T, + pub chip_id: T, + pub num_layers: T, + pub num_read_specs: T, + pub num_write_specs: T, + pub num_logup_specs: T, } define_typed_per_proof_permutation_bus!(TowerModuleBus, TowerModuleMessage); +#[repr(C)] +#[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] +pub struct TowerRootClaimMessage { + pub chip_id: T, + pub r0_claim: [T; D_EF], + pub w0_claim: [T; D_EF], + pub p0_claim: [T; D_EF], + pub q0_claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerRootClaimBus, TowerRootClaimMessage); + #[repr(C)] #[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] pub struct MainMessage { - pub idx: T, + pub chip_id: T, pub tidx: T, pub claim: [T; D_EF], } diff --git a/ceno_recursion_v2/src/circuit/inner/mod.rs b/ceno_recursion_v2/src/circuit/inner/mod.rs index 2d66e353d..a754f9b17 100644 --- a/ceno_recursion_v2/src/circuit/inner/mod.rs +++ b/ceno_recursion_v2/src/circuit/inner/mod.rs @@ -29,6 +29,8 @@ pub use trace::*; pub struct InnerCircuit { pub verifier_circuit: Arc, pub def_hook_commit: Option, + pub has_fixed_commit: bool, + pub has_fixed_no_omc_init_commit: bool, pub instance_public_value_indices: Arc>>, } @@ -67,6 +69,8 @@ impl, S: AggregationSubCircuit> Circuit for I lookup_challenge_bus, pvs_air_consistency_bus, deferral_enabled, + has_fixed_commit: self.has_fixed_commit, + has_fixed_no_omc_init_commit: self.has_fixed_no_omc_init_commit, instance_public_value_indices: self.instance_public_value_indices.clone(), }) as AirRef; diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs index 63ae7727d..1fa16f783 100644 --- a/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs @@ -1,10 +1,13 @@ use std::{borrow::Borrow, sync::Arc}; use ceno_emul::{FullTracer as Tracer, WORD_SIZE}; -use ceno_zkvm::instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, EXIT_PC, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, - HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBIO_DIGEST_IDX, - SHARD_ID_IDX, SHARD_RW_SUM_IDX, +use ceno_zkvm::{ + instructions::riscv::constants::{ + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, EXIT_PC, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBIO_DIGEST_IDX, + PUBIO_DIGEST_U16_LIMBS, SHARD_ID_IDX, SHARD_RW_SUM_IDX, + }, + structs::VK_DIGEST_LEN, }; use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; use openvm_stark_backend::{ @@ -21,7 +24,11 @@ use stark_recursion_circuit_derive::AlignedBorrow; use crate::{ bus::{LookupChallengeBus, LookupChallengeKind, LookupChallengeMessage}, - circuit::inner::{bus::PvsAirConsistencyBus, vm_pvs::VmPvs}, + circuit::inner::{ + bus::{PvsAirConsistencyBus, PvsAirConsistencyMessage}, + vm_pvs::VmPvs, + }, + utils::TranscriptLabel, }; #[repr(C)] @@ -31,10 +38,14 @@ pub struct VmPvsCols { pub is_valid: F, pub is_last: F, pub has_verifier_pvs: F, + pub vk_digest: [[F; D_EF]; VK_DIGEST_LEN], pub lookup_challenge_alpha: [F; D_EF], pub lookup_challenge_beta: [F; D_EF], pub lookup_challenge_alpha_lookup_count: F, pub lookup_challenge_beta_lookup_count: F, + pub fixed_commit_log2_max_codeword_size: F, + pub fixed_no_omc_init_commit_log2_max_codeword_size: F, + pub witness_commit_log2_max_codeword_size: F, pub child_pvs: VmPvs, } @@ -45,6 +56,8 @@ pub struct VmPvsAir { pub lookup_challenge_bus: LookupChallengeBus, pub pvs_air_consistency_bus: PvsAirConsistencyBus, pub deferral_enabled: bool, + pub has_fixed_commit: bool, + pub has_fixed_no_omc_init_commit: bool, pub instance_public_value_indices: Arc>>, } @@ -188,46 +201,135 @@ impl Air f // local.is_valid * is_leaf, // ); - // Commitments are observed after transcript-visible public values in preflight. - let start_tidx_after_public_value = VmPvs::::width() - 3 * DIGEST_SIZE; - for (didx, value) in local.child_pvs.fixed_commit.iter().enumerate() { + // Mirror the native verifier transcript prefix exactly: + // vk digest, transcript-visible public values, commitments with size + // fields, then lookup challenge samples. + for (tidx, value) in [(0usize, 1668508018u32), (1usize, 118u32)] { self.transcript_bus.receive( builder, local.proof_idx, TranscriptBusMessage { - tidx: AB::Expr::from_usize(start_tidx_after_public_value + didx), - value: (*value).into(), + tidx: AB::Expr::from_usize(tidx), + value: AB::Expr::from_u32(value), is_sample: AB::Expr::ZERO, }, local.is_valid, ); } - for (didx, value) in local.child_pvs.fixed_no_omc_init_commit.iter().enumerate() { - self.transcript_bus.receive( + + let mut transcript_tidx = TranscriptLabel::Riscv.field_len(); + for digest in local.vk_digest { + self.transcript_bus.observe_ext( builder, local.proof_idx, - TranscriptBusMessage { - tidx: AB::Expr::from_usize(start_tidx_after_public_value + DIGEST_SIZE + didx), - value: (*value).into(), - is_sample: AB::Expr::ZERO, - }, + AB::Expr::from_usize(transcript_tidx), + digest, + local.is_valid, + ); + transcript_tidx += D_EF; + } + + for instance_indices in self.instance_public_value_indices.iter() { + for global_pv_idx in instance_indices { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(transcript_tidx), + value: vm_public_value_by_index::(local, *global_pv_idx), + is_sample: AB::Expr::ZERO, + }, + local.is_valid, + ); + transcript_tidx += 1; + } + } + + if self.has_fixed_commit { + for (didx, value) in local.child_pvs.fixed_commit.iter().enumerate() { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(transcript_tidx + didx), + value: (*value).into(), + is_sample: AB::Expr::ZERO, + }, + local.is_valid, + ); + } + transcript_tidx += DIGEST_SIZE; + self.transcript_bus.observe( + builder, + local.proof_idx, + AB::Expr::from_usize(transcript_tidx), + local.fixed_commit_log2_max_codeword_size, local.is_valid, ); + transcript_tidx += 1; } + + if self.has_fixed_no_omc_init_commit { + for (didx, value) in local.child_pvs.fixed_no_omc_init_commit.iter().enumerate() { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(transcript_tidx + didx), + value: (*value).into(), + is_sample: AB::Expr::ZERO, + }, + local.is_valid, + ); + } + transcript_tidx += DIGEST_SIZE; + self.transcript_bus.observe( + builder, + local.proof_idx, + AB::Expr::from_usize(transcript_tidx), + local.fixed_no_omc_init_commit_log2_max_codeword_size, + local.is_valid, + ); + transcript_tidx += 1; + } + for (didx, value) in local.child_pvs.witness_commit.iter().enumerate() { self.transcript_bus.receive( builder, local.proof_idx, TranscriptBusMessage { - tidx: AB::Expr::from_usize( - start_tidx_after_public_value + 2 * DIGEST_SIZE + didx, - ), + tidx: AB::Expr::from_usize(transcript_tidx + didx), value: (*value).into(), is_sample: AB::Expr::ZERO, }, local.is_valid, ); } + transcript_tidx += DIGEST_SIZE; + self.transcript_bus.observe( + builder, + local.proof_idx, + AB::Expr::from_usize(transcript_tidx), + local.witness_commit_log2_max_codeword_size, + local.is_valid, + ); + transcript_tidx += 1; + + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + AB::Expr::from_usize(transcript_tidx), + local.lookup_challenge_alpha, + local.is_valid, + ); + transcript_tidx += D_EF; + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + AB::Expr::from_usize(transcript_tidx), + local.lookup_challenge_beta, + local.is_valid, + ); for i in 0..D_EF { self.lookup_challenge_bus.add_key_with_lookups( @@ -253,15 +355,15 @@ impl Air f } // We look up proof metadata from VerifierPvsAir here to ensure consistency on each row. - // self.pvs_air_consistency_bus.lookup_key( - // builder, - // local.proof_idx, - // PvsAirConsistencyMessage { - // deferral_flag, - // has_verifier_pvs: local.has_verifier_pvs.into(), - // }, - // local.is_valid, - // ); + self.pvs_air_consistency_bus.lookup_key( + builder, + local.proof_idx, + PvsAirConsistencyMessage { + deferral_flag, + has_verifier_pvs: local.has_verifier_pvs.into(), + }, + local.is_valid, + ); // Finally, constrain that this AIR's output public values are consistent with child_pvs. let &VmPvs::<_> { @@ -377,8 +479,9 @@ where { local.child_pvs.shard_rw_sum[idx - SHARD_RW_SUM_IDX].into() } - idx if idx == PUBIO_DIGEST_IDX => local.child_pvs.public_io[0].into(), - idx if idx == PUBIO_DIGEST_IDX + 1 => local.child_pvs.public_io[1].into(), + idx if (PUBIO_DIGEST_IDX..(PUBIO_DIGEST_IDX + PUBIO_DIGEST_U16_LIMBS)).contains(&idx) => { + local.child_pvs.public_io[idx - PUBIO_DIGEST_IDX].into() + } _ => AB::Expr::ZERO, } } diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs index 102876658..da9c73da9 100644 --- a/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs @@ -1,3 +1,4 @@ +use ceno_zkvm::instructions::riscv::constants::PUBIO_DIGEST_U16_LIMBS; use openvm_stark_backend::{FiatShamirTranscript, TranscriptHistory}; use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, DIGEST_SIZE, F}; use p3_field::PrimeCharacteristicRing; @@ -28,7 +29,7 @@ pub struct VmPvs { pub heap_shard_len: F, pub hint_start_addr: F, pub hint_shard_len: F, - pub public_io: [F; 2], + pub public_io: [F; PUBIO_DIGEST_U16_LIMBS], pub shard_rw_sum: [F; 2 * SEPTIC_EXTENSION_DEGREE], } @@ -44,6 +45,11 @@ pub fn run_preflight( ) where TS: FiatShamirTranscript + TranscriptHistory, { + let vk_digest = child_vk.compute_digest(); + for elem in vk_digest { + ts.observe_ext(elem); + } + // Observe public values in canonical circuit-instance order first. for (_, circuit_vk) in child_vk.circuit_vks.iter() { for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance.iter() { @@ -77,7 +83,6 @@ pub fn run_preflight( let alpha_ext = ts.sample_ext(); let beta_ext = ts.sample_ext(); - eprintln!("vm_pvs alpha {} beta {}", alpha_ext, beta_ext); preflight.vm_pvs.lookup_challenge_alpha = alpha_ext; preflight.vm_pvs.lookup_challenge_beta = beta_ext; preflight.vm_pvs.lookup_challenge_alpha_lookup_count = 0; diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs index a1262f04b..10523e93a 100644 --- a/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs @@ -1,3 +1,4 @@ +use ceno_zkvm::instructions::riscv::constants::{LIMB_BITS, LIMB_MASK, PUBIO_DIGEST_U16_LIMBS}; use openvm_cpu_backend::CpuBackend; use openvm_stark_backend::prover::AirProvingContext; use openvm_stark_sdk::config::baby_bear_poseidon2::{ @@ -46,6 +47,17 @@ pub fn generate_proving_ctx( .as_ref() .map(|commitment| commitment.commit.clone()), ); + let fixed_commit_log2_max_codeword_size = child_vk + .fixed_commit + .as_ref() + .map(|commitment| commitment.log2_max_codeword_size) + .unwrap_or_default(); + let fixed_no_omc_init_commit_log2_max_codeword_size = child_vk + .fixed_no_omc_init_commit + .as_ref() + .map(|commitment| commitment.log2_max_codeword_size) + .unwrap_or_default(); + let vk_digest = child_vk.compute_digest().map(ef_to_limbs); for (row_idx, row) in trace.chunks_exact_mut(width).enumerate() { let (base_row, def_row) = row.split_at_mut(VmPvsCols::::width()); @@ -58,12 +70,19 @@ pub fn generate_proving_ctx( cols.is_valid = F::ONE; cols.is_last = F::from_bool(row_idx + 1 == proofs.len()); cols.has_verifier_pvs = F::ZERO; + cols.vk_digest = vk_digest; cols.lookup_challenge_alpha = ef_to_limbs(preflight.vm_pvs.lookup_challenge_alpha); cols.lookup_challenge_beta = ef_to_limbs(preflight.vm_pvs.lookup_challenge_beta); cols.lookup_challenge_alpha_lookup_count = F::from_usize(preflight.vm_pvs.lookup_challenge_alpha_lookup_count); cols.lookup_challenge_beta_lookup_count = F::from_usize(preflight.vm_pvs.lookup_challenge_beta_lookup_count); + cols.fixed_commit_log2_max_codeword_size = + F::from_usize(fixed_commit_log2_max_codeword_size); + cols.fixed_no_omc_init_commit_log2_max_codeword_size = + F::from_usize(fixed_no_omc_init_commit_log2_max_codeword_size); + cols.witness_commit_log2_max_codeword_size = + F::from_usize(proof.witin_commit.log2_max_codeword_size); cols.child_pvs = build_vm_pvs(fixed_commit, fixed_no_omc_init_commit, proof); } @@ -109,7 +128,7 @@ fn build_vm_pvs( heap_shard_len: F::from_u32(pv.heap_shard_len), hint_start_addr: F::from_u32(pv.hint_start_addr), hint_shard_len: F::from_u32(pv.hint_shard_len), - public_io: split_u32_lo_hi(pv.public_io_digest[0]), + public_io: split_public_io_digest(pv.public_io_digest), shard_rw_sum: pv.shard_rw_sum.map(F::from_u32), } } @@ -131,6 +150,14 @@ fn split_u32_lo_hi(value: u32) -> [F; 2] { ] } +fn split_public_io_digest(words: [u32; 8]) -> [F; PUBIO_DIGEST_U16_LIMBS] { + core::array::from_fn(|idx| { + let word_idx = idx / 2; + let limb_idx = idx % 2; + F::from_u32((words[word_idx] >> (limb_idx * LIMB_BITS)) & LIMB_MASK) + }) +} + fn ef_to_limbs(value: EF) -> [F; D_EF] { let mut out = [F::ZERO; D_EF]; out.copy_from_slice(value.as_basis_coefficients_slice()); diff --git a/ceno_recursion_v2/src/continuation/prover/inner/mod.rs b/ceno_recursion_v2/src/continuation/prover/inner/mod.rs index a546298bb..b2d4352ea 100644 --- a/ceno_recursion_v2/src/continuation/prover/inner/mod.rs +++ b/ceno_recursion_v2/src/continuation/prover/inner/mod.rs @@ -82,6 +82,8 @@ impl< let circuit = Arc::new(InnerCircuit::new( Arc::new(verifier_circuit), def_hook_commit.map(|d| d.into()), + child_vk.fixed_commit.is_some(), + child_vk.fixed_no_omc_init_commit.is_some(), instance_public_value_indices, )); let (pk, vk) = engine.keygen(&circuit.airs()); @@ -127,6 +129,8 @@ impl< let circuit = Arc::new(InnerCircuit::new( Arc::new(verifier_circuit), def_hook_commit.map(|d| d.into()), + child_vk.fixed_commit.is_some(), + child_vk.fixed_no_omc_init_commit.is_some(), instance_public_value_indices, )); let vk = Arc::new(pk.get_vk()); @@ -153,22 +157,17 @@ impl< } fn build_instance_public_value_indices(child_vk: &RecursionVk) -> Vec> { - (0..child_vk.circuit_vks.len()) - .map(|air_idx| { - child_vk - .circuit_index_to_name - .get(&air_idx) - .and_then(|name| child_vk.circuit_vks.get(name)) - .map(|circuit_vk| { - circuit_vk - .get_cs() - .zkvm_v1_css - .instance - .iter() - .map(|instance_value| instance_value.0) - .collect() - }) - .unwrap_or_default() + child_vk + .circuit_vks + .iter() + .map(|(_, circuit_vk)| { + circuit_vk + .get_cs() + .zkvm_v1_css + .instance + .iter() + .map(|instance_value| instance_value.0) + .collect() }) .collect() } diff --git a/ceno_recursion_v2/src/continuation/prover/mod.rs b/ceno_recursion_v2/src/continuation/prover/mod.rs index a11802970..8c56dd1bf 100644 --- a/ceno_recursion_v2/src/continuation/prover/mod.rs +++ b/ceno_recursion_v2/src/continuation/prover/mod.rs @@ -67,9 +67,8 @@ impl AggregationOptions { } type CenoProof = ZKVMProof>; -type Engine = BabyBearPoseidon2CpuEngine< - openvm_stark_sdk::config::baby_bear_poseidon2::DuplexSponge, ->; +type Engine = + BabyBearPoseidon2CpuEngine; /// Full recursion pipeline that aggregates N Ceno base-layer shard proofs /// into a single compact root proof. @@ -91,9 +90,7 @@ pub struct AggProver { options: AggregationOptions, } -impl - AggProver -{ +impl AggProver { /// Create a new aggregation prover from the base-layer verifying key. pub fn new(child_vk: Arc, options: AggregationOptions) -> Self { let leaf_prover = InnerCpuProver::::new::( diff --git a/ceno_recursion_v2/src/main/air.rs b/ceno_recursion_v2/src/main/air.rs index d94927630..41aa27f71 100644 --- a/ceno_recursion_v2/src/main/air.rs +++ b/ceno_recursion_v2/src/main/air.rs @@ -1,19 +1,18 @@ use core::borrow::Borrow; -use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_circuit_primitives::utils::assert_array_eq; use openvm_stark_backend::{ BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, }; use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::Field; +use p3_field::{Field, PrimeCharacteristicRing}; use p3_matrix::Matrix; -use recursion_circuit::subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}; use stark_recursion_circuit_derive::AlignedBorrow; use crate::bus::{ MainBus, MainExpressionClaimBus, MainExpressionClaimMessage, MainMessage, MainSumcheckInputBus, - MainSumcheckInputMessage, MainSumcheckOutputBus, MainSumcheckOutputMessage, + MainSumcheckInputMessage, MainSumcheckOutputBus, MainSumcheckOutputMessage, TranscriptBus, }; #[repr(C)] @@ -22,8 +21,11 @@ pub struct MainCols { pub is_enabled: T, pub proof_idx: T, pub idx: T, + pub chip_id: T, pub is_first_idx: T, pub is_first: T, + pub has_tower: T, + pub has_sumcheck: T, pub tidx: T, pub claim_in: [T; D_EF], pub claim_out: [T; D_EF], @@ -34,6 +36,7 @@ pub struct MainAir { pub sumcheck_input_bus: MainSumcheckInputBus, pub sumcheck_output_bus: MainSumcheckOutputBus, pub expression_claim_bus: MainExpressionClaimBus, + pub transcript_bus: TranscriptBus, } impl BaseAir for MainAir { @@ -55,37 +58,51 @@ impl Air for MainAir { let local: &MainCols = (*local_row).borrow(); let next: &MainCols = (*next_row).borrow(); - type LoopSubAir = NestedForLoopSubAir<2>; - LoopSubAir {}.eval( - builder, - ( - NestedForLoopIoCols { - is_enabled: local.is_enabled, - counter: [local.proof_idx, local.idx], - is_first: [local.is_first_idx, local.is_first], - } - .map_into(), - NestedForLoopIoCols { - is_enabled: next.is_enabled, - counter: [next.proof_idx, next.idx], - is_first: [next.is_first_idx, next.is_first], - } - .map_into(), - ), - ); + builder.assert_bool(local.is_enabled); + builder.assert_bool(local.is_first_idx); + builder.assert_bool(local.is_first); + builder.assert_bool(local.has_tower); + builder.assert_bool(local.has_sumcheck); + builder + .when_transition() + .when(AB::Expr::ONE - local.is_enabled) + .assert_zero(next.is_enabled); + builder + .when_first_row() + .when(local.is_enabled) + .assert_zero(local.proof_idx); + builder + .when_first_row() + .when(local.is_enabled) + .assert_one(local.is_first_idx); + + let proof_diff = next.proof_idx - local.proof_idx; + builder + .when_transition() + .when(next.is_enabled) + .assert_bool(proof_diff.clone()); + builder + .when_transition() + .when(next.is_enabled * proof_diff.clone()) + .assert_one(next.is_first_idx); + builder + .when_transition() + .when(next.is_enabled * (AB::Expr::ONE - proof_diff)) + .assert_zero(next.is_first_idx); - let receive_mask = local.is_enabled * local.is_first; + let receive_mask = local.is_enabled * local.is_first * local.has_tower; self.main_bus.receive( builder, local.proof_idx, MainMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), tidx: local.tidx.into(), claim: local.claim_in.map(Into::into), }, receive_mask, ); + let sumcheck_mask = local.is_enabled * local.has_sumcheck; self.sumcheck_input_bus.send( builder, local.proof_idx, @@ -94,7 +111,7 @@ impl Air for MainAir { tidx: local.tidx.into(), claim: local.claim_in.map(Into::into), }, - local.is_enabled, + sumcheck_mask.clone(), ); self.sumcheck_output_bus.receive( @@ -104,7 +121,7 @@ impl Air for MainAir { idx: local.idx.into(), claim: local.claim_out.map(Into::into), }, - local.is_enabled, + sumcheck_mask, ); assert_array_eq( @@ -113,14 +130,20 @@ impl Air for MainAir { local.claim_out, ); - self.expression_claim_bus.send( - builder, - local.proof_idx, + let _ = ( + &self.expression_claim_bus, MainExpressionClaimMessage { idx: local.idx.into(), claim: local.claim_out.map(Into::into), }, - local.is_enabled * local.is_first, + ); + + self.transcript_bus.observe_ext( + builder, + local.proof_idx, + local.tidx.into(), + local.claim_in, + local.is_enabled * local.has_tower, ); } } diff --git a/ceno_recursion_v2/src/main/mod.rs b/ceno_recursion_v2/src/main/mod.rs index 5f0e1aa28..ee16efb0d 100644 --- a/ceno_recursion_v2/src/main/mod.rs +++ b/ceno_recursion_v2/src/main/mod.rs @@ -2,15 +2,14 @@ mod air; mod sumcheck; mod trace; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use ceno_zkvm::scheme::ZKVMChipProof; use eyre::{Result, bail, eyre}; use openvm_cpu_backend::CpuBackend; use openvm_poseidon2_air::POSEIDON2_WIDTH; use openvm_stark_backend::{ - AirRef, FiatShamirTranscript, ReadOnlyTranscript, StarkProtocolConfig, TranscriptHistory, - prover::AirProvingContext, + AirRef, FiatShamirTranscript, StarkProtocolConfig, TranscriptHistory, prover::AirProvingContext, }; use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, EF, F}; use p3_field::PrimeCharacteristicRing; @@ -24,12 +23,14 @@ use self::{ trace::{MainRecord, MainTraceGenerator}, }; use crate::{ - bus::{MainBus, MainExpressionClaimBus, MainSumcheckInputBus, MainSumcheckOutputBus}, + bus::{ + MainBus, MainExpressionClaimBus, MainSumcheckInputBus, MainSumcheckOutputBus, TranscriptBus, + }, system::{ AirModule, BusIndexManager, BusInventory, ChipTranscriptRange, GlobalCtxCpu, Preflight, RecursionField, RecursionProof, RecursionVk, TraceGenModule, }, - tower::convert_logup_claim, + tower::build_gkr_blob, tracegen::{ModuleChip, RowMajorChip}, }; @@ -42,6 +43,7 @@ pub struct MainModule { sumcheck_input_bus: MainSumcheckInputBus, sumcheck_output_bus: MainSumcheckOutputBus, expression_claim_bus: MainExpressionClaimBus, + transcript_bus: TranscriptBus, } impl MainModule { @@ -51,17 +53,19 @@ impl MainModule { let sumcheck_input_bus = bus_inventory.main_sumcheck_input_bus; let sumcheck_output_bus = bus_inventory.main_sumcheck_output_bus; let expression_claim_bus = bus_inventory.main_expression_claim_bus; + let transcript_bus = bus_inventory.transcript_bus; Self { main_bus, sumcheck_input_bus, sumcheck_output_bus, expression_claim_bus, + transcript_bus, } } fn collect_records( &self, - _child_vk: &RecursionVk, + child_vk: &RecursionVk, proofs: &[RecursionProof], preflights: &[Preflight], ) -> Result> { @@ -73,51 +77,68 @@ impl MainModule { ); } + let tower_blob = build_gkr_blob(child_vk, proofs, preflights)?; + let tower_inputs: BTreeMap<(usize, usize), _> = tower_blob + .input_records + .iter() + .map(|record| ((record.proof_idx, record.idx), record)) + .collect(); + let mut paired = Vec::new(); for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights).enumerate() { - let mut chip_pf_iter = preflight.main.chips.iter(); let mut saw_chip = false; - for (&chip_idx, chip_instances) in &proof.chip_proofs { - for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { - saw_chip = true; - let pf_entry = chip_pf_iter - .next() - .ok_or_else(|| eyre!( - "missing main preflight entry for chip {chip_idx} instance {instance_idx}" - ))?; - if pf_entry.chip_idx != chip_idx || pf_entry.instance_idx != instance_idx { - bail!( - "main preflight chip mismatch: expected ({}, {}), got ({}, {})", - chip_idx, - instance_idx, - pf_entry.chip_idx, - pf_entry.instance_idx - ); - } - let claim = input_layer_claim(chip_proof); - // Access the fork log directly using fork_idx and fork-local tidx. - let fork_log = preflight.fork_log(pf_entry.fork_idx); - let mut ts = ReadOnlyTranscript::new(fork_log, pf_entry.tidx); - record_main_transcript(&mut ts, chip_idx, chip_proof); + let sorted_idx_by_chip: BTreeMap = preflight + .proof_shape + .sorted_trace_vdata + .iter() + .enumerate() + .map(|(sorted_idx, (chip_idx, _))| (*chip_idx, sorted_idx)) + .collect(); + let mut sorted_pf_entries: Vec<_> = preflight.main.chips.iter().collect(); + sorted_pf_entries.sort_by_key(|entry| { + ( + sorted_idx_by_chip + .get(&entry.chip_idx) + .copied() + .unwrap_or(usize::MAX), + entry.instance_idx, + ) + }); - // Compute global tidx for trace column values. - let global_tidx = - preflight.fork_global_offset(pf_entry.fork_idx) + pf_entry.tidx; - let main_record = MainRecord { - proof_idx, - idx: chip_idx, - tidx: global_tidx, - claim, - }; - let sumcheck_record = build_sumcheck_record_from_chip( - proof_idx, - chip_idx, - claim, - chip_proof, - global_tidx, - ); - paired.push((main_record, sumcheck_record)); - } + for (entry_idx, pf_entry) in sorted_pf_entries.into_iter().enumerate() { + let chip_idx = pf_entry.chip_idx; + let instance_idx = pf_entry.instance_idx; + let chip_instances = proof + .chip_proofs + .get(&chip_idx) + .ok_or_else(|| eyre!("missing chip proof instances for chip {chip_idx}"))?; + let chip_proof = chip_instances.get(instance_idx).ok_or_else(|| { + eyre!("missing chip proof instance {instance_idx} for chip {chip_idx}") + })?; + let tower_input = tower_inputs.get(&(proof_idx, entry_idx)).ok_or_else(|| { + eyre!("missing tower input record for proof {proof_idx} idx {entry_idx}") + })?; + saw_chip = true; + + let claim = tower_input.input_layer_claim; + let global_tidx = tower_input.final_tidx; + let sumcheck_record = build_sumcheck_record_from_chip( + proof_idx, + entry_idx, + claim, + chip_proof, + global_tidx, + ); + let main_record = MainRecord { + proof_idx, + idx: entry_idx, + chip_id: chip_idx, + has_tower: tower_input.num_layers > 0, + has_sumcheck: !sumcheck_record.rounds.is_empty(), + tidx: global_tidx, + claim, + }; + paired.push((main_record, sumcheck_record)); } if !saw_chip { @@ -150,6 +171,7 @@ impl AirModule for MainModule { sumcheck_input_bus: self.sumcheck_input_bus, sumcheck_output_bus: self.sumcheck_output_bus, expression_claim_bus: self.expression_claim_bus, + transcript_bus: self.transcript_bus, }; let main_sumcheck_air = MainSumcheckAir { sumcheck_input_bus: self.sumcheck_input_bus, @@ -174,7 +196,7 @@ impl MainModule { for (&chip_idx, chip_instances) in &proof.chip_proofs { for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { let tidx = ts.len(); - record_main_transcript(ts, chip_idx, chip_proof); + record_main_transcript(ts, input_layer_claim(chip_proof)); preflight.main.chips.push(ChipTranscriptRange { chip_idx, instance_idx, @@ -199,22 +221,6 @@ impl> TraceGenModule ) -> Option>>> { let mut paired = self.collect_records(child_vk, proofs, preflights).ok()?; paired.sort_by_key(|(record, _)| (record.proof_idx, record.idx)); - // Replace chip_idx with sequential index per proof_idx group to satisfy - // NestedForLoopSubAir counter constraints (must increment by 0 or 1). - // This matches the tower module which also uses entry_idx. - { - let mut prev_proof_idx = usize::MAX; - let mut seq_idx = 0usize; - for (record, sumcheck_record) in paired.iter_mut() { - if record.proof_idx != prev_proof_idx { - seq_idx = 0; - prev_proof_idx = record.proof_idx; - } - record.idx = seq_idx; - sumcheck_record.idx = seq_idx; - seq_idx += 1; - } - } let (main_records, sumcheck_records): (Vec<_>, Vec<_>) = paired.into_iter().unzip(); let ctx = MainTraceCtx { main_records: &main_records, @@ -268,24 +274,8 @@ impl RowMajorChip for MainModuleChip { } fn input_layer_claim(chip_proof: &ZKVMChipProof) -> EF { - let layer_count = chip_proof - .tower_proof - .logup_specs_eval - .iter() - .map(|spec_layers| spec_layers.len()) - .chain( - chip_proof - .tower_proof - .prod_specs_eval - .iter() - .map(|spec_layers| spec_layers.len()), - ) - .max() - .unwrap_or(0); - if layer_count == 0 { - return EF::ZERO; - } - convert_logup_claim(chip_proof, layer_count - 1)[0] + let _ = chip_proof; + EF::ZERO } fn build_sumcheck_record_from_chip( @@ -325,12 +315,9 @@ fn build_sumcheck_record_from_chip( } } -pub(crate) fn record_main_transcript( - ts: &mut TS, - _chip_idx: usize, - chip_proof: &ZKVMChipProof, -) where +pub(crate) fn record_main_transcript(ts: &mut TS, claim: EF) +where TS: FiatShamirTranscript, { - ts.observe_ext(input_layer_claim(chip_proof)); + ts.observe_ext(claim); } diff --git a/ceno_recursion_v2/src/main/sumcheck/air.rs b/ceno_recursion_v2/src/main/sumcheck/air.rs index b9d0f139c..6ced478e8 100644 --- a/ceno_recursion_v2/src/main/sumcheck/air.rs +++ b/ceno_recursion_v2/src/main/sumcheck/air.rs @@ -1,6 +1,6 @@ use core::borrow::Borrow; -use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_circuit_primitives::utils::assert_array_eq; use openvm_stark_backend::{ BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, }; @@ -8,12 +8,9 @@ use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; use p3_matrix::Matrix; -use recursion_circuit::{ - subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, - utils::{ - assert_one_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar, - ext_field_one_minus, ext_field_subtract, - }, +use recursion_circuit::utils::{ + assert_one_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar, + ext_field_one_minus, ext_field_subtract, }; use stark_recursion_circuit_derive::AlignedBorrow; @@ -76,34 +73,53 @@ where builder.assert_bool(local.is_dummy); builder.assert_bool(local.is_last_round); builder.assert_bool(local.is_first_round); + builder.assert_bool(local.is_enabled); + builder.assert_bool(local.is_first_idx); + builder.assert_bool(next.is_first_idx); + builder.assert_bool(next.is_first_round); + builder + .when_transition() + .when(AB::Expr::ONE - local.is_enabled) + .assert_zero(next.is_enabled); + builder + .when_first_row() + .when(local.is_enabled) + .assert_zero(local.proof_idx); + builder + .when_first_row() + .when(local.is_enabled) + .assert_one(local.is_first_idx); + builder + .when(local.is_first_idx) + .assert_one(local.is_first_round); - type LoopSubAir = NestedForLoopSubAir<2>; - LoopSubAir {}.eval( - builder, - ( - NestedForLoopIoCols { - is_enabled: local.is_enabled, - counter: [local.proof_idx, local.idx], - is_first: [local.is_first_idx, local.is_first_round], - } - .map_into(), - NestedForLoopIoCols { - is_enabled: next.is_enabled, - counter: [next.proof_idx, next.idx], - is_first: [next.is_first_idx, next.is_first_round], - } - .map_into(), - ), - ); + let proof_diff = next.proof_idx - local.proof_idx; + builder + .when_transition() + .when(next.is_enabled) + .assert_bool(proof_diff.clone()); + builder + .when_transition() + .when(next.is_enabled * proof_diff.clone()) + .assert_one(next.is_first_idx); + builder + .when_transition() + .when(next.is_enabled * (AB::Expr::ONE - proof_diff)) + .assert_zero(next.is_first_idx); - let is_transition_round = - LoopSubAir::local_is_transition(next.is_enabled, next.is_first_round); + let is_transition_round = next.is_enabled * (AB::Expr::ONE - next.is_first_round); let computed_is_last = - LoopSubAir::local_is_last(local.is_enabled, next.is_enabled, next.is_first_round); + local.is_enabled * (AB::Expr::ONE - next.is_enabled + next.is_first_round); builder .when(local.is_enabled) .assert_eq(local.is_last_round, computed_is_last.clone()); + builder + .when(is_transition_round.clone()) + .assert_eq(next.proof_idx, local.proof_idx); + builder + .when(is_transition_round.clone()) + .assert_eq(next.idx, local.idx); builder.when(local.is_first_round).assert_zero(local.round); builder diff --git a/ceno_recursion_v2/src/main/trace.rs b/ceno_recursion_v2/src/main/trace.rs index 5d7e36e53..d8a8dd498 100644 --- a/ceno_recursion_v2/src/main/trace.rs +++ b/ceno_recursion_v2/src/main/trace.rs @@ -11,6 +11,9 @@ use crate::tracegen::RowMajorChip; pub struct MainRecord { pub proof_idx: usize, pub idx: usize, + pub chip_id: usize, + pub has_tower: bool, + pub has_sumcheck: bool, pub tidx: usize, pub claim: EF, } @@ -88,8 +91,11 @@ fn fill_main_cols(record: &MainRecord, cols: &mut MainCols, is_first_idx: boo cols.is_enabled = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.is_first_idx = F::from_bool(is_first_idx); cols.is_first = F::from_bool(is_first); + cols.has_tower = F::from_bool(record.has_tower); + cols.has_sumcheck = F::from_bool(record.has_sumcheck); cols.tidx = F::from_usize(record.tidx); let claim_basis: [F; D_EF] = record .claim diff --git a/ceno_recursion_v2/src/proof_shape/mod.rs b/ceno_recursion_v2/src/proof_shape/mod.rs index cf931b2b4..b1745130a 100644 --- a/ceno_recursion_v2/src/proof_shape/mod.rs +++ b/ceno_recursion_v2/src/proof_shape/mod.rs @@ -164,7 +164,17 @@ impl ProofShapeModule { .chip_proofs .values() .flat_map(|instances| instances.iter()) - .map(|chip_proof| chip_proof.tower_proof.proofs.len()) + .map(|chip_proof| { + let proof_layers = chip_proof.tower_proof.proofs.len(); + let has_root_specs = !chip_proof.r_out_evals.is_empty() + || !chip_proof.w_out_evals.is_empty() + || !chip_proof.lk_out_evals.is_empty(); + if proof_layers == 0 && !has_root_specs { + 0 + } else { + proof_layers + 1 + } + }) .max() .unwrap_or(0); @@ -254,6 +264,7 @@ impl AirModule for ProofShapeModule { fraction_folder_input_bus: self.bus_inventory.fraction_folder_input_bus, expression_claim_n_max_bus: self.bus_inventory.expression_claim_n_max_bus, tower_module_bus: self.bus_inventory.tower_module_bus, + tower_root_claim_bus: self.bus_inventory.tower_root_claim_bus, air_shape_bus: self.bus_inventory.air_shape_bus, hyperdim_bus: self.bus_inventory.hyperdim_bus, lifted_heights_bus: self.bus_inventory.lifted_heights_bus, diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs index 853848d02..ebc5abad2 100644 --- a/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs @@ -4,7 +4,7 @@ use itertools::fold; use openvm_circuit_primitives::{ SubAir, encoder::Encoder, - utils::{and, not, or}, + utils::{and, not}, }; use openvm_stark_backend::{ BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, @@ -21,10 +21,9 @@ use crate::{ ForkedTranscriptBus, ForkedTranscriptBusMessage, FractionFolderInputBus, FractionFolderInputMessage, HyperdimBus, HyperdimBusMessage, LiftedHeightsBus, LiftedHeightsBusMessage, LookupChallengeBus, LookupChallengeKind, LookupChallengeMessage, - NLiftBus, NLiftMessage, TowerModuleBus, TowerModuleMessage, TranscriptBus, - TranscriptBusMessage, + NLiftBus, NLiftMessage, TowerModuleBus, TowerModuleMessage, TowerRootClaimBus, + TowerRootClaimMessage, TranscriptBus, TranscriptBusMessage, }, - circuit::inner::vm_pvs::VmPvs, primitives::bus::{RangeCheckerBus, RangeCheckerBusMessage}, proof_shape::{ AirMetadata, @@ -53,6 +52,8 @@ pub struct ProofShapeCols { pub log_height: F, /// Whether this AIR needs rotation openings. pub need_rot: F, + /// Number of tower reduction layers for this chip proof. + pub num_tower_layers: F, // First possible transcript index of the current AIR. pub starting_tidx: F, @@ -68,6 +69,8 @@ pub struct ProofShapeCols { // Number of present AIRs so far pub num_present: F, + /// Native fork transcript id for this chip proof. + pub fork_id: F, /// Limb decomposition of per-instance heights used for range/decomposition checks. pub height_1_limbs: [F; NUM_LIMBS], @@ -85,8 +88,16 @@ pub struct ProofShapeCols { /// forking). Constrained to be identical across all rows within a proof. pub lookup_challenge_alpha: [F; D_EF], pub lookup_challenge_beta: [F; D_EF], - pub after_forked_challenge_1: [F; D_EF], - pub after_forked_challenge_2: [F; D_EF], + /// Fork-local tidx of the final fork sample merged back into the trunk transcript. + pub fork_sample_tidx: F, + /// Trunk tidx where this fork's final sample is observed during merge. + pub merge_tidx: F, + pub fork_merge_sample: [F; D_EF], + + pub r0_claim: [F; D_EF], + pub w0_claim: [F; D_EF], + pub p0_claim: [F; D_EF], + pub q0_claim: [F; D_EF], } // Variable-length columns are stored at the end @@ -113,6 +124,7 @@ pub struct ProofShapeAir { // Inter-module buses pub tower_module_bus: TowerModuleBus, + pub tower_root_claim_bus: TowerRootClaimBus, pub air_shape_bus: AirShapeBus, pub expression_claim_n_max_bus: ExpressionClaimNMaxBus, pub fraction_folder_input_bus: FractionFolderInputBus, @@ -157,6 +169,7 @@ where let local: &ProofShapeCols = (*local)[..const_width].borrow(); let next: &ProofShapeCols = (*next)[..const_width].borrow(); let n = local.log_height.into(); + let num_tower_layers: AB::Expr = local.num_tower_layers.into(); self.idx_encoder.eval(builder, localv.idx_flags); @@ -204,11 +217,11 @@ where // Select values for NumPublicValuesBus let mut num_pvs = AB::Expr::ZERO; + // Per-selected-air tower transcript span (used for fork challenge tidx bump). + let mut tower_tidx_bump = AB::Expr::ZERO; let mut num_read_count = AB::Expr::ZERO; let mut num_write_count = AB::Expr::ZERO; let mut num_logup_count = AB::Expr::ZERO; - // Per-selected-air tower transcript span (used for fork challenge tidx bump). - let mut tower_tidx_bump = AB::Expr::ZERO; for (i, air_data) in self.per_air.iter().enumerate() { // We keep a running tally of how many transcript reads there should be up to any @@ -222,25 +235,21 @@ where num_fixed += is_current_air.clone() * AB::F::from_usize(air_data.num_fixed); num_pvs += is_current_air.clone() * AB::F::from_usize(air_data.num_public_values); + num_read_count += is_current_air.clone() * AB::F::from_usize(air_data.num_read_count); + num_write_count += is_current_air.clone() * AB::F::from_usize(air_data.num_write_count); + num_logup_count += is_current_air.clone() * AB::F::from_usize(air_data.num_logup_count); if air_data.is_required { is_required += is_current_air.clone(); when_current.assert_one(local.is_present); } - num_read_count += - is_current_air.clone() * AB::Expr::from_usize(air_data.num_read_count); - num_write_count += - is_current_air.clone() * AB::Expr::from_usize(air_data.num_write_count); - num_logup_count += - is_current_air.clone() * AB::Expr::from_usize(air_data.num_logup_count); - // Keep this aligned with TowerInputAir's `tidx_after_gkr_layers` // arithmetic so fork challenge placement and tower buses share one // transcript span model. tower_tidx_bump += is_current_air * per_air_tower_span::( - n.clone(), + num_tower_layers.clone(), air_data.num_read_count, air_data.num_write_count, air_data.num_logup_count, @@ -290,7 +299,7 @@ where value: local.log_height - next.log_height, max_bits: AB::Expr::from_usize(5), }, - and(local.is_valid, not(next.is_last)), + AB::Expr::ZERO, ); /////////////////////////////////////////////////////////////////////////////////////////// @@ -305,7 +314,7 @@ where word_idx: AB::Expr::from_usize(i), value: local.lookup_challenge_alpha[i].into(), }, - local.is_present * local.is_valid, + AB::Expr::ZERO, ); } for i in 0..D_EF { @@ -317,7 +326,7 @@ where word_idx: AB::Expr::from_usize(i), value: local.lookup_challenge_beta[i].into(), }, - local.is_present * local.is_valid, + AB::Expr::ZERO, ); } @@ -325,15 +334,6 @@ where // TRANSCRIPT OBSERVATIONS /////////////////////////////////////////////////////////////////////////////////////////// - let is_first_idx = self.idx_encoder.get_flag_expr::(0, localv.idx_flags); - - // The first AIR starts immediately after the fixed trunk transcript prefix. - builder.when(is_first_idx.clone()).assert_eq( - local.starting_tidx, - AB::Expr::from_usize(TranscriptLabel::Riscv.field_len() + VmPvs::::width()) - + AB::Expr::from_usize(2 * D_EF), - ); - self.starting_tidx_bus.receive( builder, local.proof_idx, @@ -342,44 +342,10 @@ where + AB::Expr::from_usize(self.per_air.len()) * local.is_last, tidx: local.starting_tidx.into(), }, - or( - local.is_last, - and(local.is_valid, not::(is_first_idx)), - ), + AB::Expr::ZERO, ); - // Challenges are laid out in trunk transcript as contiguous EF limbs per present AIR. - // We jump directly to this AIR's segment using num_present (1-based among present AIRs). - let mut tidx = - local.starting_tidx.into() + local.num_present * AB::Expr::from_usize(2 * D_EF); - - for i in 0..D_EF { - self.transcript_bus.receive( - builder, - local.proof_idx, - TranscriptBusMessage { - tidx: tidx.clone() + AB::Expr::from_usize(i), - value: local.after_forked_challenge_1[i].into(), - is_sample: AB::Expr::ZERO, - }, - local.is_present, - ); - } - tidx += AB::Expr::from_usize(D_EF) * local.is_present; - - for i in 0..D_EF { - self.transcript_bus.receive( - builder, - local.proof_idx, - TranscriptBusMessage { - tidx: tidx.clone() + AB::Expr::from_usize(i), - value: local.after_forked_challenge_2[i].into(), - is_sample: AB::Expr::ZERO, - }, - local.is_present, - ); - } - tidx += AB::Expr::from_usize(D_EF) * local.is_present; + let tidx = local.starting_tidx + local.is_present * tower_tidx_bump; // constrain next air tid self.starting_tidx_bus.send( @@ -389,7 +355,7 @@ where air_idx: air_idx.clone() + AB::F::ONE, tidx, }, - local.is_valid, + AB::Expr::ZERO, ); /////////////////////////////////////////////////////////////////////////////////////////// @@ -402,7 +368,18 @@ where // chips would require a separate fork_id column. // Receive fork transcript words after the fork label prefix. let fork_tidx_base = TranscriptLabel::Fork.field_len(); - let fork_id = local.num_present - AB::F::ONE; + let fork_id = local.fork_id.into(); + self.forked_transcript_bus.receive( + builder, + local.proof_idx, + ForkedTranscriptBusMessage { + fork_id: fork_id.clone(), + tidx: AB::Expr::ZERO, + value: AB::Expr::from_usize(1802661734), // u32::from_le_bytes(*b"fork") + is_sample: AB::Expr::ZERO, + }, + local.is_present * local.is_valid, + ); // observe lookup alpha/beta for i in 0..D_EF { self.forked_transcript_bus.receive( @@ -457,20 +434,22 @@ where ForkedTranscriptBusMessage { fork_id: fork_id.clone().into(), tidx: AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 2), - value: local.log_height.into(), + value: local.height_1.into(), + is_sample: AB::Expr::ZERO, + }, + local.is_present * local.is_valid, + ); + self.forked_transcript_bus.receive( + builder, + local.proof_idx, + ForkedTranscriptBusMessage { + fork_id: fork_id.clone().into(), + tidx: AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 3), + value: local.height_2.into(), is_sample: AB::Expr::ZERO, }, local.is_present * local.is_valid, ); - - // Skip the full per-air tower transcript span (out-evals, alpha/beta, - // and all GKR/sumcheck layer transcript activity) before binding the - // post-fork sampled challenges. - let forked_challenge_1_tidx = - AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 3) + tower_tidx_bump; - // Challenge 2 starts after challenge 1 plus the product_sum label span. - let forked_challenge_2_tidx = - forked_challenge_1_tidx.clone() + AB::Expr::from_usize(tower_transcript_len::BETA_LEN); for i in 0..D_EF { self.forked_transcript_bus.receive( @@ -478,20 +457,19 @@ where local.proof_idx, ForkedTranscriptBusMessage { fork_id: fork_id.clone().into(), - tidx: forked_challenge_1_tidx.clone() + AB::Expr::from_usize(i), - value: local.after_forked_challenge_1[i].into(), + tidx: local.fork_sample_tidx.into() + AB::Expr::from_usize(i), + value: local.fork_merge_sample[i].into(), is_sample: AB::Expr::ONE, }, local.is_present * local.is_valid, ); - self.forked_transcript_bus.receive( + self.transcript_bus.receive( builder, local.proof_idx, - ForkedTranscriptBusMessage { - fork_id: fork_id.clone().into(), - tidx: forked_challenge_2_tidx.clone() + AB::Expr::from_usize(i), - value: local.after_forked_challenge_2[i].into(), - is_sample: AB::Expr::ONE, + TranscriptBusMessage { + tidx: local.merge_tidx.into() + AB::Expr::from_usize(i), + value: local.fork_merge_sample[i].into(), + is_sample: AB::Expr::ZERO, }, local.is_present * local.is_valid, ); @@ -508,7 +486,7 @@ where property_idx: AirShapeProperty::AirId.to_field(), value: air_idx.clone(), }, - local.is_present * local.num_air_id_lookups, + AB::Expr::ZERO, ); self.air_shape_bus.add_key_with_lookups( @@ -519,7 +497,7 @@ where property_idx: AirShapeProperty::NumInteractions.to_field(), value: AB::Expr::ZERO, }, - local.is_present, + AB::Expr::ZERO, ); self.air_shape_bus.add_key_with_lookups( @@ -530,40 +508,8 @@ where property_idx: AirShapeProperty::NeedRot.to_field(), value: local.need_rot.into(), }, - local.is_present * local.num_columns, - ); - self.air_shape_bus.add_key_with_lookups( - builder, - local.proof_idx, - AirShapeBusMessage { - sort_idx: local.sorted_idx.into(), - property_idx: AirShapeProperty::NumRead.to_field(), - value: num_read_count.clone(), - }, - // each layer lookup once if current air was present - local.is_present * n.clone(), - ); - self.air_shape_bus.add_key_with_lookups( - builder, - local.proof_idx, - AirShapeBusMessage { - sort_idx: local.sorted_idx.into(), - property_idx: AirShapeProperty::NumWrite.to_field(), - value: num_write_count.clone(), - }, - local.is_present * n.clone(), - ); - self.air_shape_bus.add_key_with_lookups( - builder, - local.proof_idx, - AirShapeBusMessage { - sort_idx: local.sorted_idx.into(), - property_idx: AirShapeProperty::NumLk.to_field(), - value: num_logup_count, - }, - local.is_present * n.clone(), + AB::Expr::ZERO, ); - /////////////////////////////////////////////////////////////////////////////////////////// // HYPERDIM LOOKUP /////////////////////////////////////////////////////////////////////////////////////////// @@ -581,7 +527,7 @@ where value: n.clone(), max_bits: AB::Expr::from_usize(5), }, - local.is_present, + AB::Expr::ZERO, ); self.hyperdim_bus.add_key_with_lookups( @@ -592,7 +538,7 @@ where n_abs: n.clone(), n_sign_bit: AB::Expr::ZERO, }, - local.is_present * (local.num_air_id_lookups + AB::F::ONE), + AB::Expr::ZERO, ); /////////////////////////////////////////////////////////////////////////////////////////// @@ -622,7 +568,7 @@ where lifted_height: combined_height.into(), log_lifted_height: local.log_height.into(), }, - local.is_present * (num_witin + num_structural_witin + num_fixed), + AB::Expr::ZERO, ); /////////////////////////////////////////////////////////////////////////////////////////// @@ -640,7 +586,7 @@ where value: local.height_1_limbs[i].into(), max_bits: AB::Expr::from_usize(LIMB_BITS), }, - local.is_valid, + AB::Expr::ZERO, ); self.range_bus.lookup_key( builder, @@ -648,7 +594,7 @@ where value: local.height_2_limbs[i].into(), max_bits: AB::Expr::from_usize(LIMB_BITS), }, - local.is_valid, + AB::Expr::ZERO, ); } @@ -686,18 +632,33 @@ where * (local.is_n_max_greater * AB::F::TWO - AB::F::ONE), max_bits: AB::Expr::from_usize(5), }, - local.is_last, + AB::Expr::ZERO, ); self.tower_module_bus.send( builder, local.proof_idx, TowerModuleMessage { - idx: air_idx.clone(), - tidx: local.starting_tidx.into(), - n_logup: n, + chip_id: air_idx.clone(), + num_layers: num_tower_layers, + num_read_specs: num_read_count.clone(), + num_write_specs: num_write_count.clone(), + num_logup_specs: num_logup_count.clone(), + }, + local.is_present * local.is_valid, + ); + + self.tower_root_claim_bus.receive( + builder, + local.proof_idx, + TowerRootClaimMessage { + chip_id: air_idx.clone(), + r0_claim: local.r0_claim.map(Into::into), + w0_claim: local.w0_claim.map(Into::into), + p0_claim: local.p0_claim.map(Into::into), + q0_claim: local.q0_claim.map(Into::into), }, - local.is_last, + local.is_present * local.is_valid, ); // Send n_max value to expression claim air @@ -707,7 +668,7 @@ where ExpressionClaimNMaxMessage { n_max: local.n_max.into(), }, - local.is_last, + AB::Expr::ZERO, ); // Send n_lift to constraint folding air @@ -718,7 +679,7 @@ where air_idx: air_idx, n_lift: local.log_height.into(), }, - local.is_present, + AB::Expr::ZERO, ); // Send count of present airs to fraction folder air @@ -728,7 +689,7 @@ where FractionFolderInputMessage { num_present_airs: local.num_present, }, - local.is_last, + AB::Expr::ZERO, ); } } diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs index 444e78e4d..b5fe8fce0 100644 --- a/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs @@ -8,7 +8,7 @@ use p3_matrix::dense::RowMajorMatrix; use super::air::ProofShapeCols; use crate::{ primitives::{pow::PowerCheckerCpuTraceGenerator, range::RangeCheckerCpuTraceGenerator}, - system::{POW_CHECKER_HEIGHT, Preflight, RecursionProof, RecursionVk}, + system::{POW_CHECKER_HEIGHT, Preflight, RecursionField, RecursionProof, RecursionVk}, tracegen::RowMajorChip, }; @@ -22,6 +22,52 @@ fn borrow_var_cols_mut(slice: &mut [F], idx_flags: usize) -> ProofShapeVarCol } } +fn root_claims_for_chip(proof: &RecursionProof, air_idx: usize) -> (EF, EF, EF, EF) { + let Some(chip_proof) = proof + .chip_proofs + .get(&air_idx) + .and_then(|instances| instances.first()) + else { + return (EF::ZERO, EF::ZERO, EF::ZERO, EF::ZERO); + }; + + let r0 = chip_proof + .r_out_evals + .iter() + .map(|evals| evals[0] * evals[1]) + .product::(); + let w0 = chip_proof + .w_out_evals + .iter() + .map(|evals| evals[0] * evals[1]) + .product::(); + let mut p0 = EF::ZERO; + let mut q0 = EF::ONE; + for evals in &chip_proof.lk_out_evals { + let p_cross = evals[0] * evals[3] + evals[1] * evals[2]; + let q_cross = evals[2] * evals[3]; + p0 = p0 * q_cross + p_cross * q0; + q0 *= q_cross; + } + (r0, w0, p0, q0) +} + +fn assign_ext(dst: &mut [F; D_EF], value: EF) { + dst.copy_from_slice(value.as_basis_coefficients_slice()); +} + +fn tower_layer_count(chip_proof: &ceno_zkvm::scheme::ZKVMChipProof) -> usize { + let proof_layers = chip_proof.tower_proof.proofs.len(); + let has_root_specs = !chip_proof.r_out_evals.is_empty() + || !chip_proof.w_out_evals.is_empty() + || !chip_proof.lk_out_evals.is_empty(); + if proof_layers == 0 && !has_root_specs { + 0 + } else { + proof_layers + 1 + } +} + fn decompose_usize( mut value: usize, ) -> [usize; NUM_LIMBS] { @@ -47,6 +93,17 @@ fn two_instance_heights_from_chip_instances( }) } +fn fork_merge_sample(preflight: &Preflight, fork_id: usize) -> Option<(usize, [F; D_EF])> { + let fork_log = preflight + .fork_transcripts + .iter() + .find(|fork| fork.fork_id == fork_id)?; + let sample_tidx = fork_log.log.len().checked_sub(D_EF)?; + let mut sample = [F::ZERO; D_EF]; + sample.copy_from_slice(&fork_log.log.values()[sample_tidx..sample_tidx + D_EF]); + Some((sample_tidx, sample)) +} + trait BorrowNumInstances { fn borrow_num_instances(&self) -> &[usize]; } @@ -100,6 +157,18 @@ impl RowMajorChip let mut chunks = trace.chunks_exact_mut(width); for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights.iter()).enumerate() { + let fork_id_by_chip: std::collections::BTreeMap = proof + .chip_proofs + .iter() + .flat_map(|(chip_idx, instances)| { + instances + .iter() + .enumerate() + .map(move |(instance_idx, _)| (*chip_idx, instance_idx)) + }) + .enumerate() + .map(|(fork_id, (chip_idx, _instance_idx))| (chip_idx, fork_id)) + .collect(); let mut sorted_idx = 0usize; let mut num_present = 0usize; @@ -124,11 +193,23 @@ impl RowMajorChip cols.sorted_idx = F::from_usize(sorted_idx); cols.log_height = F::from_usize(log_height); cols.need_rot = F::ZERO; + cols.num_tower_layers = F::from_usize( + proof + .chip_proofs + .get(air_idx) + .and_then(|instances| instances.iter().map(tower_layer_count).max()) + .unwrap_or(0), + ); cols.starting_tidx = F::from_usize(preflight.proof_shape.starting_tidx[*air_idx]); cols.is_present = F::ONE; cols.height_1 = F::from_usize(height_1); cols.height_2 = F::from_usize(height_2); cols.num_present = F::from_usize(num_present); + let fork_id = fork_id_by_chip + .get(air_idx) + .copied() + .unwrap_or(num_present.saturating_sub(1)); + cols.fork_id = F::from_usize(fork_id); cols.height_1_limbs = decompose_usize::(height_1).map(F::from_usize); cols.height_2_limbs = @@ -139,10 +220,17 @@ impl RowMajorChip cols.num_columns = F::ZERO; cols.lookup_challenge_alpha = preflight.proof_shape.lookup_challenge_alpha; cols.lookup_challenge_beta = preflight.proof_shape.lookup_challenge_beta; - cols.after_forked_challenge_1 = - ef_to_limbs(preflight.proof_shape.after_forked_challenge_1); - cols.after_forked_challenge_2 = - ef_to_limbs(preflight.proof_shape.after_forked_challenge_2); + if let Some((sample_tidx, sample)) = fork_merge_sample(preflight, fork_id) { + cols.fork_sample_tidx = F::from_usize(sample_tidx); + cols.merge_tidx = + F::from_usize(preflight.proof_shape.fork_start_tidx + fork_id * D_EF); + cols.fork_merge_sample = sample; + } + let (r0, w0, p0, q0) = root_claims_for_chip(proof, *air_idx); + assign_ext(&mut cols.r0_claim, r0); + assign_ext(&mut cols.w0_claim, w0); + assign_ext(&mut cols.p0_claim, p0); + assign_ext(&mut cols.q0_claim, q0); for (dst, src) in var_cols .idx_flags @@ -152,7 +240,6 @@ impl RowMajorChip *dst = F::from_u32(*src); } - self.pow_checker.add_pow(log_height); sorted_idx += 1; } @@ -172,11 +259,13 @@ impl RowMajorChip cols.sorted_idx = F::from_usize(sorted_idx); cols.log_height = F::ZERO; cols.need_rot = F::ZERO; + cols.num_tower_layers = F::ZERO; cols.starting_tidx = F::from_usize(preflight.proof_shape.starting_tidx[air_idx]); cols.is_present = F::ZERO; cols.height_1 = F::ZERO; cols.height_2 = F::ZERO; cols.num_present = F::from_usize(num_present); + cols.fork_id = F::ZERO; cols.height_1_limbs = [F::ZERO; NUM_LIMBS]; cols.height_2_limbs = [F::ZERO; NUM_LIMBS]; cols.n_max = F::from_usize(preflight.proof_shape.n_max); @@ -185,10 +274,9 @@ impl RowMajorChip cols.num_columns = F::ZERO; cols.lookup_challenge_alpha = preflight.proof_shape.lookup_challenge_alpha; cols.lookup_challenge_beta = preflight.proof_shape.lookup_challenge_beta; - cols.after_forked_challenge_1 = - ef_to_limbs(preflight.proof_shape.after_forked_challenge_1); - cols.after_forked_challenge_2 = - ef_to_limbs(preflight.proof_shape.after_forked_challenge_2); + cols.fork_sample_tidx = F::ZERO; + cols.merge_tidx = F::ZERO; + cols.fork_merge_sample = [F::ZERO; D_EF]; for (dst, src) in var_cols .idx_flags @@ -212,11 +300,13 @@ impl RowMajorChip cols.sorted_idx = F::ZERO; cols.log_height = F::from_usize(preflight.proof_shape.n_logup); cols.need_rot = F::ZERO; + cols.num_tower_layers = F::from_usize(preflight.proof_shape.n_logup); cols.starting_tidx = F::from_usize(preflight.proof_shape.post_tidx); cols.is_present = F::ZERO; cols.height_1 = F::ZERO; cols.height_2 = F::ZERO; cols.num_present = F::from_usize(num_present); + cols.fork_id = F::ZERO; cols.height_1_limbs = [F::ZERO; NUM_LIMBS]; cols.height_2_limbs = [F::ZERO; NUM_LIMBS]; cols.n_max = F::from_usize(preflight.proof_shape.n_max); @@ -226,10 +316,9 @@ impl RowMajorChip cols.num_columns = F::ZERO; cols.lookup_challenge_alpha = preflight.proof_shape.lookup_challenge_alpha; cols.lookup_challenge_beta = preflight.proof_shape.lookup_challenge_beta; - cols.after_forked_challenge_1 = - ef_to_limbs(preflight.proof_shape.after_forked_challenge_1); - cols.after_forked_challenge_2 = - ef_to_limbs(preflight.proof_shape.after_forked_challenge_2); + cols.fork_sample_tidx = F::ZERO; + cols.merge_tidx = F::ZERO; + cols.fork_merge_sample = [F::ZERO; D_EF]; } for chunk in chunks { @@ -240,9 +329,3 @@ impl RowMajorChip Some(RowMajorMatrix::new(trace, width)) } } - -fn ef_to_limbs(value: EF) -> [F; D_EF] { - let mut out = [F::ZERO; D_EF]; - out.copy_from_slice(value.as_basis_coefficients_slice()); - out -} diff --git a/ceno_recursion_v2/src/proof_shape/pvs/air.rs b/ceno_recursion_v2/src/proof_shape/pvs/air.rs index 7626001b4..74baea306 100644 --- a/ceno_recursion_v2/src/proof_shape/pvs/air.rs +++ b/ceno_recursion_v2/src/proof_shape/pvs/air.rs @@ -8,9 +8,12 @@ use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{PrimeCharacteristicRing, PrimeField32}; use p3_matrix::Matrix; +use ceno_zkvm::structs::VK_DIGEST_LEN; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; + use crate::{ - bus::{PublicValuesBus, PublicValuesBusMessage, TranscriptBus, TranscriptBusMessage}, - proof_shape::bus::{NumPublicValuesBus, NumPublicValuesMessage}, + bus::{PublicValuesBus, PublicValuesBusMessage, TranscriptBus}, + proof_shape::bus::NumPublicValuesBus, subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, utils::TranscriptLabel, }; @@ -90,12 +93,12 @@ where .assert_one(next.is_first_in_air); let is_same_air = local.is_valid * next.is_valid * not(next.is_first_in_air); - // TODO fix first tidx to be TranscriptLabel::Riscv.field_len() - // TODO fix comment as well - // first tidx happened here + let first_public_value_tidx = + AB::Expr::from_usize(TranscriptLabel::Riscv.field_len() + VK_DIGEST_LEN * D_EF); + builder .when(local.is_valid * local.is_first_in_proof * local.is_first_in_air) - .assert_zero(local.tidx); + .assert_eq(local.tidx, first_public_value_tidx); // self.num_pvs_bus.receive( // builder, @@ -123,29 +126,9 @@ where }, local.is_valid, ); - if self.continuations_enabled { - self.public_values_bus.send( - builder, - local.proof_idx, - PublicValuesBusMessage { - air_idx: local.air_idx, - pv_idx: local.pv_idx, - value: local.value, - }, - local.is_valid, - ); - } - - // Receive transcript read of public values - self.transcript_bus.receive( - builder, - local.proof_idx, - TranscriptBusMessage { - tidx: local.tidx.into(), - value: local.value.into(), - is_sample: AB::Expr::ZERO, - }, - local.is_valid, - ); + let _ = self.continuations_enabled; + + // VmPvsAir owns the verifier transcript prefix, including these public values. + let _ = &self.transcript_bus; } } diff --git a/ceno_recursion_v2/src/proof_shape/pvs/trace.rs b/ceno_recursion_v2/src/proof_shape/pvs/trace.rs index 13e6b645b..a131ac313 100644 --- a/ceno_recursion_v2/src/proof_shape/pvs/trace.rs +++ b/ceno_recursion_v2/src/proof_shape/pvs/trace.rs @@ -1,6 +1,7 @@ use core::borrow::BorrowMut; -use openvm_stark_sdk::config::baby_bear_poseidon2::F; +use ceno_zkvm::structs::VK_DIGEST_LEN; +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, F}; use p3_field::PrimeCharacteristicRing; use p3_matrix::dense::RowMajorMatrix; @@ -8,6 +9,7 @@ use crate::{ proof_shape::pvs::PublicValuesCols, system::{Preflight, RecursionField, RecursionProof, RecursionVk}, tracegen::RowMajorChip, + utils::TranscriptLabel, }; pub struct PublicValuesTraceGenerator; @@ -43,8 +45,7 @@ impl RowMajorChip for PublicValuesTraceGenerator { for (proof_idx, proof) in proofs.iter().enumerate() { let mut is_first_in_proof = true; - // TODO first tidx start from TranscriptLabel::Riscv.field_len() - let mut tidx = 0usize; + let mut tidx = TranscriptLabel::Riscv.field_len() + VK_DIGEST_LEN * D_EF; for (air_idx, (_, circuit_vk)) in child_vk.circuit_vks.iter().enumerate() { let instance_openings = &circuit_vk.get_cs().zkvm_v1_css.instance; diff --git a/ceno_recursion_v2/src/system/bus_inventory.rs b/ceno_recursion_v2/src/system/bus_inventory.rs index 4c304d542..e3db36d49 100644 --- a/ceno_recursion_v2/src/system/bus_inventory.rs +++ b/ceno_recursion_v2/src/system/bus_inventory.rs @@ -13,7 +13,8 @@ use crate::bus::{ FractionFolderInputBus as LocalFractionFolderInputBus, HyperdimBus as LocalHyperdimBus, LiftedHeightsBus as LocalLiftedHeightsBus, LookupChallengeBus, MainBus, MainExpressionClaimBus, MainSumcheckInputBus, MainSumcheckOutputBus, NLiftBus as LocalNLiftBus, - PublicValuesBus as LocalPublicValuesBus, TowerModuleBus, TranscriptBus as LocalTranscriptBus, + PublicValuesBus as LocalPublicValuesBus, TowerModuleBus, TowerRootClaimBus, + TranscriptBus as LocalTranscriptBus, }; #[derive(Clone, Debug)] @@ -23,6 +24,7 @@ pub struct BusInventory { pub poseidon2_compress_bus: Poseidon2CompressBus, pub merkle_verify_bus: MerkleVerifyBus, pub tower_module_bus: TowerModuleBus, + pub tower_root_claim_bus: TowerRootClaimBus, pub expression_claim_n_max_bus: LocalExpressionClaimNMaxBus, pub fraction_folder_input_bus: LocalFractionFolderInputBus, pub air_shape_bus: AirShapeBus, @@ -54,6 +56,7 @@ impl BusInventory { let gkr_bus_idx = b.new_bus_idx(); let tower_module_bus = TowerModuleBus::new(gkr_bus_idx); + let tower_root_claim_bus = TowerRootClaimBus::new(b.new_bus_idx()); let air_shape_bus = AirShapeBus::new(b.new_bus_idx()); let hyperdim_bus = LocalHyperdimBus::new(b.new_bus_idx()); @@ -85,6 +88,7 @@ impl BusInventory { poseidon2_compress_bus, merkle_verify_bus, tower_module_bus, + tower_root_claim_bus, expression_claim_n_max_bus, fraction_folder_input_bus, air_shape_bus, diff --git a/ceno_recursion_v2/src/system/mod.rs b/ceno_recursion_v2/src/system/mod.rs index 0012264cb..68dc0774e 100644 --- a/ceno_recursion_v2/src/system/mod.rs +++ b/ceno_recursion_v2/src/system/mod.rs @@ -251,6 +251,28 @@ impl VerifierSubCircuit { Self::new_with_options(child_vk, VerifierConfig::default()) } + fn vm_pvs_lookup_challenges_from_transcript(sponge: &TS) -> (EF, EF) + where + TS: Clone + TranscriptHistory, + { + let log = sponge.clone().into_log(); + let values = log.values(); + let samples = log.samples(); + let start = values + .len() + .checked_sub(2 * D_EF) + .expect("VM PVS transcript must end with alpha/beta extension samples"); + debug_assert!( + samples[start..start + 2 * D_EF] + .iter() + .all(|is_sample| *is_sample), + "VM PVS transcript suffix must contain alpha/beta samples" + ); + let alpha = EF::from_basis_coefficients_fn(|i| values[start + i]); + let beta = EF::from_basis_coefficients_fn(|i| values[start + D_EF + i]); + (alpha, beta) + } + pub fn new_with_options(child_vk: Arc, config: VerifierConfig) -> Self { // let child_mvk = convert_vk_from_zkvm(child_vk.as_ref()); // let proof_shape_constraint = LinearConstraint { @@ -350,16 +372,22 @@ impl VerifierSubCircuit { + Clone, { let mut preflight = Preflight::default(); + let (alpha_ext, beta_ext) = Self::vm_pvs_lookup_challenges_from_transcript(&sponge); + preflight.vm_pvs.lookup_challenge_alpha = alpha_ext; + preflight.vm_pvs.lookup_challenge_beta = beta_ext; // Phase 1: Trunk operations. // Proof-shape metadata and alpha/beta sampling after pre-verifier transcript observes. self.proof_shape .run_preflight(child_vk, proof, &mut preflight, &mut sponge); - - // VmPvs is owned by pre-system preflight. Consume vm_pvs challenge - // fields directly here. - let alpha_ext = preflight.vm_pvs.lookup_challenge_alpha; - let beta_ext = preflight.vm_pvs.lookup_challenge_beta; + let num_lookup_challenge_consumers = preflight.proof_shape.sorted_trace_vdata.len(); + preflight.vm_pvs.lookup_challenge_alpha_lookup_count = num_lookup_challenge_consumers; + preflight.vm_pvs.lookup_challenge_beta_lookup_count = num_lookup_challenge_consumers; + + // VmPvs is owned by the pre-system preflight. The incoming transcript + // has already sampled these two challenges; recover and forward them + // so forked chip transcripts bind the same alpha/beta as the native + // verifier. preflight.proof_shape.lookup_challenge_alpha = ef_to_limbs(alpha_ext); preflight.proof_shape.lookup_challenge_beta = ef_to_limbs(beta_ext); @@ -400,13 +428,26 @@ impl VerifierSubCircuit { } let tower_tidx = fs.len(); - let tower_replay = + let (tower_schedule, tower_replay) = crate::tower::record_and_replay_tower_preflight(fs, child_vk, chip_idx, chip_proof); + let main_claim = crate::tower::derive_tower_input_claim_for_transcript( + child_vk, + chip_idx, + chip_proof, + &tower_replay, + &tower_schedule, + ); + // Record tower entry with fork-local tidx at tower stage start. preflight.gkr.chips.push(TowerChipTranscriptRange { chip_idx, instance_idx, + num_layers: crate::tower::circuit_vk_for_idx(child_vk, chip_idx) + .map(|circuit_vk| { + crate::tower::tower_layer_count_from_vk(circuit_vk, chip_proof) + }) + .unwrap_or(0), tidx: tower_tidx, fork_idx: fork_id, tower_replay, @@ -414,7 +455,7 @@ impl VerifierSubCircuit { // Main preflight for this chip. let main_tidx = fs.len(); - crate::main::record_main_transcript(fs, chip_idx, chip_proof); + crate::main::record_main_transcript(fs, main_claim); preflight.main.chips.push(ChipTranscriptRange { chip_idx, diff --git a/ceno_recursion_v2/src/system/preflight/mod.rs b/ceno_recursion_v2/src/system/preflight/mod.rs index ae05071b0..c8bfdf9ad 100644 --- a/ceno_recursion_v2/src/system/preflight/mod.rs +++ b/ceno_recursion_v2/src/system/preflight/mod.rs @@ -93,6 +93,8 @@ pub struct TowerPreflight { pub struct TowerChipTranscriptRange { pub chip_idx: usize, pub instance_idx: usize, + /// Number of tower layers represented by AIR rows. + pub num_layers: usize, /// Fork-local tidx (position within the fork's transcript log). pub tidx: usize, /// Index into `Preflight::fork_transcripts`. diff --git a/ceno_recursion_v2/src/tower/bus.rs b/ceno_recursion_v2/src/tower/bus.rs index ab41c3c30..48fe9003a 100644 --- a/ceno_recursion_v2/src/tower/bus.rs +++ b/ceno_recursion_v2/src/tower/bus.rs @@ -6,7 +6,7 @@ use crate::define_typed_per_proof_permutation_bus; #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerXiSamplerMessage { - pub idx: T, + pub chip_id: T, pub tidx: T, } @@ -16,11 +16,13 @@ define_typed_per_proof_permutation_bus!(TowerXiSamplerBus, TowerXiSamplerMessage #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerLayerInputMessage { - pub idx: T, + pub chip_id: T, pub tidx: T, - pub r0_claim: [T; D_EF], - pub w0_claim: [T; D_EF], - pub q0_claim: [T; D_EF], + pub num_layers: T, + pub num_read_specs: T, + pub num_write_specs: T, + pub num_logup_specs: T, + pub initial_tower_claim: [T; D_EF], } define_typed_per_proof_permutation_bus!(TowerLayerInputBus, TowerLayerInputMessage); @@ -29,11 +31,11 @@ define_typed_per_proof_permutation_bus!(TowerLayerInputBus, TowerLayerInputMessa #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerLayerOutputMessage { - pub idx: T, + pub chip_id: T, pub tidx: T, pub layer_idx_end: T, pub input_layer_claim: [T; D_EF], - pub lambda: [T; D_EF], + pub lambda_next: [T; D_EF], pub mu: [T; D_EF], } @@ -41,29 +43,31 @@ define_typed_per_proof_permutation_bus!(TowerLayerOutputBus, TowerLayerOutputMes #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] -pub struct TowerProdLayerChallengeMessage { - pub idx: T, +pub struct TowerProdLayerInputMessage { + pub chip_id: T, pub layer_idx: T, pub tidx: T, - pub lambda: [T; D_EF], - pub lambda_prime: [T; D_EF], + pub lambda_next: [T; D_EF], + pub lambda_cur: [T; D_EF], pub mu: [T; D_EF], + pub prod_offset: T, + pub lambda_next_start: [T; D_EF], + pub lambda_cur_start: [T; D_EF], + pub num_prod_count: T, } -define_typed_per_proof_permutation_bus!(TowerProdReadClaimInputBus, TowerProdLayerChallengeMessage); -define_typed_per_proof_permutation_bus!( - TowerProdWriteClaimInputBus, - TowerProdLayerChallengeMessage -); +define_typed_per_proof_permutation_bus!(TowerProdReadClaimInputBus, TowerProdLayerInputMessage); +define_typed_per_proof_permutation_bus!(TowerProdWriteClaimInputBus, TowerProdLayerInputMessage); #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerProdSumClaimMessage { - pub idx: T, + pub chip_id: T, pub layer_idx: T, - pub lambda_claim: [T; D_EF], - pub lambda_prime_claim: [T; D_EF], - pub num_prod_count: T, + pub lambda_next_claim: [T; D_EF], + pub lambda_cur_claim: [T; D_EF], + pub lambda_next_end: [T; D_EF], + pub lambda_cur_end: [T; D_EF], } define_typed_per_proof_permutation_bus!(TowerProdReadClaimBus, TowerProdSumClaimMessage); @@ -72,12 +76,15 @@ define_typed_per_proof_permutation_bus!(TowerProdWriteClaimBus, TowerProdSumClai #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerLogupLayerChallengeMessage { - pub idx: T, + pub chip_id: T, pub layer_idx: T, pub tidx: T, - pub lambda: [T; D_EF], - pub lambda_prime: [T; D_EF], + pub lambda_next: [T; D_EF], + pub lambda_cur: [T; D_EF], pub mu: [T; D_EF], + pub lambda_next_start: [T; D_EF], + pub lambda_cur_start: [T; D_EF], + pub num_logup_count: T, } define_typed_per_proof_permutation_bus!(TowerLogupClaimInputBus, TowerLogupLayerChallengeMessage); @@ -85,21 +92,86 @@ define_typed_per_proof_permutation_bus!(TowerLogupClaimInputBus, TowerLogupLayer #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerLogupClaimMessage { - pub idx: T, + pub chip_id: T, pub layer_idx: T, - pub lambda_claim: [T; D_EF], - pub lambda_prime_claim: [T; D_EF], - pub num_logup_count: T, + pub lambda_next_claim: [T; D_EF], + pub lambda_cur_claim: [T; D_EF], } define_typed_per_proof_permutation_bus!(TowerLogupClaimBus, TowerLogupClaimMessage); +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerProdRootInputMessage { + pub chip_id: T, + pub tidx: T, + pub lambda_1: [T; D_EF], + pub r_1: [T; D_EF], + pub lambda_1_start: [T; D_EF], + pub num_prod_count: T, +} + +define_typed_per_proof_permutation_bus!(TowerReadRootInputBus, TowerProdRootInputMessage); +define_typed_per_proof_permutation_bus!(TowerWriteRootInputBus, TowerProdRootInputMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerProdRootMessage { + pub chip_id: T, + pub output_claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerReadRootBus, TowerProdRootMessage); +define_typed_per_proof_permutation_bus!(TowerWriteRootBus, TowerProdRootMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerProdInitMessage { + pub chip_id: T, + pub initial_claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerReadInitBus, TowerProdInitMessage); +define_typed_per_proof_permutation_bus!(TowerWriteInitBus, TowerProdInitMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLogupRootInputMessage { + pub chip_id: T, + pub tidx: T, + pub lambda_1: [T; D_EF], + pub r_1: [T; D_EF], + pub lambda_1_start: [T; D_EF], + pub num_logup_count: T, +} + +define_typed_per_proof_permutation_bus!(TowerLogupRootInputBus, TowerLogupRootInputMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLogupRootMessage { + pub chip_id: T, + pub p0_claim: [T; D_EF], + pub q0_claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerLogupRootBus, TowerLogupRootMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLogupInitMessage { + pub chip_id: T, + pub initial_claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerLogupInitBus, TowerLogupInitMessage); + /// Message sent from TowerLayerAir to TowerLayerSumcheckAir #[repr(C)] #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerSumcheckInputMessage { /// Module index within the proof - pub idx: T, + pub chip_id: T, /// GKR layer index pub layer_idx: T, pub is_last_layer: T, @@ -116,7 +188,7 @@ define_typed_per_proof_permutation_bus!(TowerSumcheckInputBus, TowerSumcheckInpu #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerSumcheckOutputMessage { /// Module index within the proof - pub idx: T, + pub chip_id: T, /// GKR layer index pub layer_idx: T, /// Transcript index after sumcheck @@ -134,7 +206,7 @@ define_typed_per_proof_permutation_bus!(TowerSumcheckOutputBus, TowerSumcheckOut #[derive(AlignedBorrow, Debug, Clone)] pub struct TowerSumcheckChallengeMessage { /// Module index within the proof - pub idx: T, + pub chip_id: T, /// GKR layer index pub layer_idx: T, /// Sumcheck round number diff --git a/ceno_recursion_v2/src/tower/input/air.rs b/ceno_recursion_v2/src/tower/input/air.rs index d30ed4c7e..a3642b289 100644 --- a/ceno_recursion_v2/src/tower/input/air.rs +++ b/ceno_recursion_v2/src/tower/input/air.rs @@ -1,15 +1,22 @@ use core::borrow::Borrow; use crate::{ - bus::{MainBus, MainMessage, TowerModuleBus, TowerModuleMessage, TranscriptBus}, + bus::{ + MainBus, MainMessage, TowerModuleBus, TowerModuleMessage, TowerRootClaimBus, + TowerRootClaimMessage, TranscriptBus, + }, tower::bus::{ TowerLayerInputBus, TowerLayerInputMessage, TowerLayerOutputBus, TowerLayerOutputMessage, + TowerLogupInitBus, TowerLogupInitMessage, TowerLogupRootBus, TowerLogupRootInputBus, + TowerLogupRootInputMessage, TowerLogupRootMessage, TowerProdInitMessage, + TowerProdRootInputMessage, TowerProdRootMessage, TowerReadInitBus, TowerReadRootBus, + TowerReadRootInputBus, TowerWriteInitBus, TowerWriteRootBus, TowerWriteRootInputBus, }, }; use openvm_circuit_primitives::{ SubAir, is_zero::{IsZeroAuxCols, IsZeroIo, IsZeroSubAir}, - utils::not, + utils::{assert_array_eq, not}, }; use openvm_stark_backend::{ BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, @@ -18,10 +25,7 @@ use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{Field, PrimeCharacteristicRing}; use p3_matrix::Matrix; -use recursion_circuit::{ - subairs::proof_idx::{ProofIdxIoCols, ProofIdxSubAir}, - utils::assert_zeros, -}; +use recursion_circuit::utils::assert_zeros; use stark_recursion_circuit_derive::AlignedBorrow; #[repr(C)] @@ -32,23 +36,38 @@ pub struct TowerInputCols { pub proof_idx: T, pub idx: T, + pub chip_id: T, - pub n_logup: T, + pub num_layers: T, + pub num_read_specs: T, + pub num_write_specs: T, + pub num_logup_specs: T, /// Flag indicating whether there are any interactions - /// n_logup = 0 <=> total_interactions = 0 - pub is_n_logup_zero: T, - pub is_n_logup_zero_aux: IsZeroAuxCols, + /// num_layers = 0 <=> total_interactions = 0 + pub is_num_layers_zero: T, + pub is_num_layers_zero_aux: IsZeroAuxCols, /// Transcript index pub tidx: T, + pub final_tidx: T, pub r0_claim: [T; D_EF], pub w0_claim: [T; D_EF], + /// Root numerator claim + pub p0_claim: [T; D_EF], /// Root denominator claim pub q0_claim: [T; D_EF], pub alpha_logup: [T; D_EF], + pub r_1: [T; D_EF], + + pub read_initial_claim: [T; D_EF], + pub write_initial_claim: [T; D_EF], + pub logup_initial_claim: [T; D_EF], + pub initial_tower_claim: [T; D_EF], + pub write_lambda_1_start: [T; D_EF], + pub logup_lambda_1_start: [T; D_EF], pub input_layer_claim: [T; D_EF], pub layer_output_lambda: [T; D_EF], @@ -59,10 +78,20 @@ pub struct TowerInputCols { pub struct TowerInputAir { // Buses pub tower_module_bus: TowerModuleBus, + pub tower_root_claim_bus: TowerRootClaimBus, pub main_bus: MainBus, pub transcript_bus: TranscriptBus, pub layer_input_bus: TowerLayerInputBus, pub layer_output_bus: TowerLayerOutputBus, + pub read_root_input_bus: TowerReadRootInputBus, + pub read_root_bus: TowerReadRootBus, + pub read_init_bus: TowerReadInitBus, + pub write_root_input_bus: TowerWriteRootInputBus, + pub write_root_bus: TowerWriteRootBus, + pub write_init_bus: TowerWriteInitBus, + pub logup_root_input_bus: TowerLogupRootInputBus, + pub logup_root_bus: TowerLogupRootBus, + pub logup_init_bus: TowerLogupInitBus, } impl BaseAir for TowerInputAir { @@ -88,40 +117,48 @@ impl Air for TowerInputAir { // Proof Index Constraints /////////////////////////////////////////////////////////////////////// - // This subair has the following constraints: - // 1. Boolean enabled flag - // 2. Disabled rows are followed by disabled rows - // 3. Proof index increments by exactly one between enabled rows - ProofIdxSubAir.eval( - builder, - ( - ProofIdxIoCols { - is_enabled: local.is_enabled, - proof_idx: local.proof_idx, - } - .map_into(), - ProofIdxIoCols { - is_enabled: next.is_enabled, - proof_idx: next.proof_idx, - } - .map_into(), - ), - ); + builder.assert_bool(local.is_enabled); + builder + .when_transition() + .when(AB::Expr::ONE - local.is_enabled) + .assert_zero(next.is_enabled); + builder + .when_first_row() + .when(local.is_enabled) + .assert_zero(local.proof_idx); + builder + .when_first_row() + .when(local.is_enabled) + .assert_zero(local.idx); + + let proof_diff: AB::Expr = next.proof_idx - local.proof_idx; + builder + .when_transition() + .when(next.is_enabled) + .assert_bool(proof_diff.clone()); + builder + .when_transition() + .when(next.is_enabled * proof_diff.clone()) + .assert_zero(next.idx); + builder + .when_transition() + .when(next.is_enabled * (AB::Expr::ONE - proof_diff)) + .assert_eq(next.idx, local.idx + AB::Expr::ONE); /////////////////////////////////////////////////////////////////////// // Base Constraints /////////////////////////////////////////////////////////////////////// - // 1. Check if n_logup is zero (no logup constraints needed) + // 1. Check if num_layers is zero (no tower reduction needed) IsZeroSubAir.eval( builder, ( IsZeroIo::new( - local.n_logup.into(), - local.is_n_logup_zero.into(), + local.num_layers.into(), + local.is_num_layers_zero.into(), local.is_enabled.into(), ), - local.is_n_logup_zero_aux.inv, + local.is_num_layers_zero_aux.inv, ), ); @@ -129,7 +166,19 @@ impl Air for TowerInputAir { // Output Constraints /////////////////////////////////////////////////////////////////////// - let has_interactions = AB::Expr::ONE - local.is_n_logup_zero; + let has_interactions = AB::Expr::ONE - local.is_num_layers_zero; + let initial_sum = { + let read_plus_write = recursion_circuit::utils::ext_field_add::( + local.read_initial_claim, + local.write_initial_claim, + ); + recursion_circuit::utils::ext_field_add::( + read_plus_write, + local.logup_initial_claim, + ) + }; + assert_array_eq(builder, local.initial_tower_claim, initial_sum); + // Input layer claim defaults to zero when no interactions assert_zeros( &mut builder.when(not::(has_interactions.clone())), @@ -148,35 +197,126 @@ impl Air for TowerInputAir { // Module Interactions /////////////////////////////////////////////////////////////////////// - let num_layers = local.n_logup; + let num_layers = local.num_layers; + let prod_eval_span = AB::Expr::from_usize(2 * D_EF); + let read_claim_tidx = local.tidx; + let write_claim_tidx = read_claim_tidx + local.num_read_specs * prod_eval_span.clone(); + let logup_claim_tidx = write_claim_tidx.clone() + local.num_write_specs * prod_eval_span; + let one = { + let mut arr = core::array::from_fn(|_| AB::Expr::ZERO); + arr[0] = AB::Expr::ONE; + arr + }; + + self.read_root_input_bus.send( + builder, + local.proof_idx, + TowerProdRootInputMessage { + chip_id: local.chip_id.into(), + tidx: read_claim_tidx.into(), + lambda_1: local.alpha_logup.map(Into::into), + r_1: local.r_1.map(Into::into), + lambda_1_start: one.clone(), + num_prod_count: local.num_read_specs.into(), + }, + local.is_enabled * local.num_read_specs, + ); + self.write_root_input_bus.send( + builder, + local.proof_idx, + TowerProdRootInputMessage { + chip_id: local.chip_id.into(), + tidx: write_claim_tidx, + lambda_1: local.alpha_logup.map(Into::into), + r_1: local.r_1.map(Into::into), + lambda_1_start: local.write_lambda_1_start.map(Into::into), + num_prod_count: local.num_write_specs.into(), + }, + local.is_enabled * local.num_write_specs, + ); + self.logup_root_input_bus.send( + builder, + local.proof_idx, + TowerLogupRootInputMessage { + chip_id: local.chip_id.into(), + tidx: logup_claim_tidx, + lambda_1: local.alpha_logup.map(Into::into), + r_1: local.r_1.map(Into::into), + lambda_1_start: local.logup_lambda_1_start.map(Into::into), + num_logup_count: local.num_logup_specs.into(), + }, + local.is_enabled * local.num_logup_specs, + ); + self.read_root_bus.receive( + builder, + local.proof_idx, + TowerProdRootMessage { + chip_id: local.chip_id.into(), + output_claim: local.r0_claim.map(Into::into), + }, + local.is_enabled * local.num_read_specs, + ); + self.write_root_bus.receive( + builder, + local.proof_idx, + TowerProdRootMessage { + chip_id: local.chip_id.into(), + output_claim: local.w0_claim.map(Into::into), + }, + local.is_enabled * local.num_write_specs, + ); + self.logup_root_bus.receive( + builder, + local.proof_idx, + TowerLogupRootMessage { + chip_id: local.chip_id.into(), + p0_claim: local.p0_claim.map(Into::into), + q0_claim: local.q0_claim.map(Into::into), + }, + local.is_enabled * local.num_logup_specs, + ); + self.read_init_bus.receive( + builder, + local.proof_idx, + TowerProdInitMessage { + chip_id: local.chip_id.into(), + initial_claim: local.read_initial_claim.map(Into::into), + }, + local.is_enabled * local.num_read_specs, + ); + self.write_init_bus.receive( + builder, + local.proof_idx, + TowerProdInitMessage { + chip_id: local.chip_id.into(), + initial_claim: local.write_initial_claim.map(Into::into), + }, + local.is_enabled * local.num_write_specs, + ); + self.logup_init_bus.receive( + builder, + local.proof_idx, + TowerLogupInitMessage { + chip_id: local.chip_id.into(), + initial_claim: local.logup_initial_claim.map(Into::into), + }, + local.is_enabled * local.num_logup_specs, + ); // Add PoW (if any) and alpha label+sample, beta label+sample - use crate::tower::tower_transcript_len::{ - ALPHA_BETA_LEN, ALPHA_LEN, POST_SUMCHECK_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, - }; - let tidx_after_alpha_beta = local.tidx + AB::Expr::from_usize(ALPHA_BETA_LEN); - // Add GKR layers + Sumcheck. - // Total GKR span: n*(10n+25) - 13 for n>0. - // layers_cumulative(n) = 10n² + 25n - 13. - let gkr_inner = num_layers.clone() * AB::Expr::from_usize(ROUND_LEN / 2) - + AB::Expr::from_usize( - ALPHA_LEN + SUMCHECK_INIT_LEN + POST_SUMCHECK_LEN - ROUND_LEN / 2, - ); - let tidx_after_gkr_layers = tidx_after_alpha_beta.clone() - + has_interactions.clone() - * (num_layers.clone() * gkr_inner - - AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN)); // 1. TowerLayerInputBus // 1a. Send input to TowerLayerAir self.layer_input_bus.send( builder, local.proof_idx, TowerLayerInputMessage { - idx: local.idx.into(), - tidx: tidx_after_alpha_beta.clone() * has_interactions.clone(), - r0_claim: local.r0_claim.map(Into::into), - w0_claim: local.w0_claim.map(Into::into), - q0_claim: local.q0_claim.map(Into::into), + chip_id: local.chip_id.into(), + tidx: local.tidx.into(), + num_layers: local.num_layers.into(), + num_read_specs: local.num_read_specs.into(), + num_write_specs: local.num_write_specs.into(), + num_logup_specs: local.num_logup_specs.into(), + initial_tower_claim: local.initial_tower_claim.map(Into::into), }, local.is_enabled * has_interactions.clone(), ); @@ -186,11 +326,11 @@ impl Air for TowerInputAir { builder, local.proof_idx, TowerLayerOutputMessage { - idx: local.idx.into(), - tidx: tidx_after_gkr_layers.clone(), + chip_id: local.chip_id.into(), + tidx: local.final_tidx.into(), layer_idx_end: num_layers - AB::Expr::ONE, input_layer_claim: local.input_layer_claim.map(Into::into), - lambda: local.layer_output_lambda.map(Into::into), + lambda_next: local.layer_output_lambda.map(Into::into), mu: local.layer_output_mu.map(Into::into), }, local.is_enabled * has_interactions.clone(), @@ -205,28 +345,34 @@ impl Air for TowerInputAir { builder, local.proof_idx, TowerModuleMessage { - idx: local.idx.into(), - tidx: local.tidx.into(), - n_logup: local.n_logup.into(), + chip_id: local.chip_id.into(), + num_layers: local.num_layers.into(), + num_read_specs: local.num_read_specs.into(), + num_write_specs: local.num_write_specs.into(), + num_logup_specs: local.num_logup_specs.into(), }, local.is_enabled, ); - // 2. TranscriptBus - // 2a. Sample alpha_logup challenge - self.transcript_bus.sample_ext( + self.tower_root_claim_bus.send( builder, local.proof_idx, - local.tidx, - local.alpha_logup.map(Into::into), + TowerRootClaimMessage { + chip_id: local.chip_id.into(), + r0_claim: local.r0_claim.map(Into::into), + w0_claim: local.w0_claim.map(Into::into), + p0_claim: local.p0_claim.map(Into::into), + q0_claim: local.q0_claim.map(Into::into), + }, local.is_enabled, ); + self.main_bus.send( builder, local.proof_idx, MainMessage { - idx: local.idx.into(), - tidx: tidx_after_gkr_layers.clone(), + chip_id: local.chip_id.into(), + tidx: local.final_tidx.into(), claim: local.input_layer_claim.map(Into::into), }, local.is_enabled * has_interactions, diff --git a/ceno_recursion_v2/src/tower/input/trace.rs b/ceno_recursion_v2/src/tower/input/trace.rs index 6b7340604..7c3d56dc0 100644 --- a/ceno_recursion_v2/src/tower/input/trace.rs +++ b/ceno_recursion_v2/src/tower/input/trace.rs @@ -11,9 +11,25 @@ use p3_matrix::dense::RowMajorMatrix; pub struct TowerInputRecord { pub proof_idx: usize, pub idx: usize, + pub chip_id: usize, pub tidx: usize, - pub n_logup: usize, + pub final_tidx: usize, + pub num_layers: usize, + pub num_read_specs: usize, + pub num_write_specs: usize, + pub num_logup_specs: usize, + pub r0_claim: EF, + pub w0_claim: EF, + pub p0_claim: EF, + pub q0_claim: EF, pub alpha_logup: EF, + pub r_1: EF, + pub read_initial_claim: EF, + pub write_initial_claim: EF, + pub logup_initial_claim: EF, + pub initial_tower_claim: EF, + pub write_lambda_1_start: EF, + pub logup_lambda_1_start: EF, pub input_layer_claim: EF, pub layer_output_lambda: EF, pub layer_output_mu: EF, @@ -22,8 +38,7 @@ pub struct TowerInputRecord { pub struct TowerInputTraceGenerator; impl RowMajorChip for TowerInputTraceGenerator { - // (gkr_input_records, q0_claims) - type Ctx<'a> = (&'a [TowerInputRecord], &'a [EF]); + type Ctx<'a> = &'a [TowerInputRecord]; #[tracing::instrument(level = "trace", skip_all)] fn generate_trace( @@ -31,8 +46,7 @@ impl RowMajorChip for TowerInputTraceGenerator { ctx: &Self::Ctx<'_>, required_height: Option, ) -> Option> { - let (gkr_input_records, q0_claims) = ctx; - debug_assert_eq!(gkr_input_records.len(), q0_claims.len()); + let gkr_input_records = *ctx; let width = TowerInputCols::::width(); @@ -50,33 +64,88 @@ impl RowMajorChip for TowerInputTraceGenerator { let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); - for (row_data, (record, q0_claim)) in data_slice + for (row_data, record) in data_slice .chunks_exact_mut(width) - .zip(gkr_input_records.iter().zip(q0_claims.iter())) + .zip(gkr_input_records.iter()) { let cols: &mut TowerInputCols = row_data.borrow_mut(); cols.is_enabled = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.tidx = F::from_usize(record.tidx); + cols.final_tidx = F::from_usize(record.final_tidx); - cols.n_logup = F::from_usize(record.n_logup); + cols.num_layers = F::from_usize(record.num_layers); + cols.num_read_specs = F::from_usize(record.num_read_specs); + cols.num_write_specs = F::from_usize(record.num_write_specs); + cols.num_logup_specs = F::from_usize(record.num_logup_specs); IsZeroSubAir.generate_subrow( - cols.n_logup, - (&mut cols.is_n_logup_zero_aux.inv, &mut cols.is_n_logup_zero), + cols.num_layers, + ( + &mut cols.is_num_layers_zero_aux.inv, + &mut cols.is_num_layers_zero, + ), ); - let q0_basis = q0_claim.as_basis_coefficients_slice(); - cols.r0_claim.copy_from_slice(q0_basis); - cols.w0_claim.copy_from_slice(q0_basis); - cols.q0_claim.copy_from_slice(q0_basis); + cols.r0_claim = record + .r0_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.w0_claim = record + .w0_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.p0_claim = record + .p0_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.q0_claim = record + .q0_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); cols.alpha_logup = record .alpha_logup .as_basis_coefficients_slice() .try_into() .unwrap(); + cols.r_1 = record.r_1.as_basis_coefficients_slice().try_into().unwrap(); + cols.read_initial_claim = record + .read_initial_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.write_initial_claim = record + .write_initial_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.logup_initial_claim = record + .logup_initial_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.initial_tower_claim = record + .initial_tower_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.write_lambda_1_start = record + .write_lambda_1_start + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.logup_lambda_1_start = record + .logup_lambda_1_start + .as_basis_coefficients_slice() + .try_into() + .unwrap(); cols.input_layer_claim = record .input_layer_claim .as_basis_coefficients_slice() diff --git a/ceno_recursion_v2/src/tower/layer/air.rs b/ceno_recursion_v2/src/tower/layer/air.rs index ce786b746..0e9522516 100644 --- a/ceno_recursion_v2/src/tower/layer/air.rs +++ b/ceno_recursion_v2/src/tower/layer/air.rs @@ -11,18 +11,16 @@ use p3_matrix::Matrix; use stark_recursion_circuit_derive::AlignedBorrow; use crate::{ - bus::{AirShapeBus, AirShapeBusMessage}, - proof_shape::bus::AirShapeProperty, + bus::AirShapeBus, tower::{ TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, bus::{ TowerLayerInputBus, TowerLayerInputMessage, TowerLayerOutputBus, TowerLayerOutputMessage, TowerLogupClaimBus, TowerLogupClaimInputBus, - TowerLogupClaimMessage, TowerLogupLayerChallengeMessage, - TowerProdLayerChallengeMessage, TowerProdReadClaimBus, TowerProdReadClaimInputBus, - TowerProdSumClaimMessage, TowerProdWriteClaimBus, TowerProdWriteClaimInputBus, - TowerSumcheckInputBus, TowerSumcheckInputMessage, TowerSumcheckOutputBus, - TowerSumcheckOutputMessage, + TowerLogupClaimMessage, TowerLogupLayerChallengeMessage, TowerProdLayerInputMessage, + TowerProdReadClaimBus, TowerProdReadClaimInputBus, TowerProdSumClaimMessage, + TowerProdWriteClaimBus, TowerProdWriteClaimInputBus, TowerSumcheckInputBus, + TowerSumcheckInputMessage, TowerSumcheckOutputBus, TowerSumcheckOutputMessage, }, }, }; @@ -40,6 +38,7 @@ pub struct TowerLayerCols { pub is_enabled: T, pub proof_idx: T, pub idx: T, + pub chip_id: T, pub is_first_air_idx: T, pub is_first: T, @@ -68,16 +67,22 @@ pub struct TowerLayerCols { pub write_claim_prime: [T; D_EF], pub logup_claim: [T; D_EF], pub logup_claim_prime: [T; D_EF], + pub read_eval_claim: [T; D_EF], + pub write_eval_claim: [T; D_EF], + pub logup_eval_claim: [T; D_EF], + pub read_lambda_end: [T; D_EF], + pub read_lambda_prime_end: [T; D_EF], + pub write_lambda_end: [T; D_EF], + pub write_lambda_prime_end: [T; D_EF], pub num_read_count: T, pub num_write_count: T, pub num_logup_count: T, + pub num_layers: T, /// Received from TowerLayerSumcheckAir pub eq_at_r_prime: [T; D_EF], - pub r0_claim: [T; D_EF], - pub w0_claim: [T; D_EF], - pub q0_claim: [T; D_EF], + pub initial_tower_claim: [T; D_EF], } /// The TowerLayerAir handles layer-to-layer transitions in the GKR protocol @@ -198,22 +203,51 @@ where let read_plus_write = ext_field_add::(local.read_claim, local.write_claim); let folded_claim = ext_field_add::(read_plus_write, local.logup_claim); + assert_array_eq( + &mut builder.when(local.is_first), + folded_claim.clone(), + local.initial_tower_claim, + ); assert_array_eq( &mut builder.when(is_transition.clone()), next.sumcheck_claim_in, folded_claim.clone(), ); + assert_array_eq( + &mut builder.when(is_transition.clone()), + next.read_eval_claim, + local.read_claim_prime, + ); + assert_array_eq( + &mut builder.when(is_transition.clone()), + next.write_eval_claim, + local.write_claim_prime, + ); + assert_array_eq( + &mut builder.when(is_transition.clone()), + next.logup_eval_claim, + local.logup_claim_prime, + ); // Transcript index increment use crate::tower::tower_transcript_len::{ - ALPHA_LEN, POST_SUMCHECK_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, + ALPHA_BETA_LEN, ALPHA_LEN, LABEL_COMBINE, LABEL_COMBINE_VALUES, LABEL_MERGE, + LABEL_MERGE_VALUES, LABEL_PRODUCT_SUM, LABEL_PRODUCT_SUM_VALUES, MERGE_LEN, ROUND_LEN, + SUMCHECK_INIT_LEN, }; - let tidx_after_sumcheck = local.tidx - // Sample lambda label+sample on non-root layer - + (AB::Expr::ONE - local.is_first) - * AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN) + let out_eval_span = (local.num_read_count * AB::Expr::from_usize(2) + + local.num_write_count * AB::Expr::from_usize(2) + + local.num_logup_count * AB::Expr::from_usize(4)) + * AB::Expr::from_usize(D_EF); + let non_root = AB::Expr::ONE - local.is_first; + let sumcheck_span = AB::Expr::from_usize(SUMCHECK_INIT_LEN) + local.layer_idx * AB::Expr::from_usize(ROUND_LEN); - let tidx_end = tidx_after_sumcheck.clone() + AB::Expr::from_usize(POST_SUMCHECK_LEN); + let tidx_after_sumcheck = local.tidx + non_root.clone() * sumcheck_span.clone(); + let root_span = out_eval_span.clone() + AB::Expr::from_usize(ALPHA_BETA_LEN); + let non_root_span = + sumcheck_span + out_eval_span.clone() + AB::Expr::from_usize(MERGE_LEN + ALPHA_LEN); + let layer_span = local.is_first * root_span + non_root.clone() * non_root_span; + let tidx_end = local.tidx + layer_span; builder .when(is_transition.clone()) .assert_eq(next.tidx, tidx_end.clone()); @@ -225,132 +259,106 @@ where let is_not_dummy = AB::Expr::ONE - local.is_dummy; let is_non_root_layer = local.is_enabled * (AB::Expr::ONE - local.is_first); - let lookup_enable = local.is_enabled * is_not_dummy.clone(); - self.air_shape_bus.lookup_key( - builder, - local.proof_idx, - AirShapeBusMessage { - sort_idx: local.idx.into(), - property_idx: AirShapeProperty::NumRead.to_field(), - value: local.num_read_count.into(), - }, - lookup_enable.clone(), - ); - self.air_shape_bus.lookup_key( - builder, - local.proof_idx, - AirShapeBusMessage { - sort_idx: local.idx.into(), - property_idx: AirShapeProperty::NumWrite.to_field(), - value: local.num_write_count.into(), - }, - lookup_enable.clone(), - ); - self.air_shape_bus.lookup_key( - builder, - local.proof_idx, - AirShapeBusMessage { - sort_idx: local.idx.into(), - property_idx: AirShapeProperty::NumLk.to_field(), - value: local.num_logup_count.into(), - }, - lookup_enable.clone(), - ); - + let active_non_dummy = local.is_enabled * is_not_dummy.clone(); let tidx_for_claims = tidx_after_sumcheck.clone(); + let prod_eval_span = AB::Expr::from_usize(2 * D_EF); + let read_tidx = tidx_for_claims.clone(); + let write_tidx = read_tidx.clone() + local.num_read_count * prod_eval_span.clone(); + let logup_tidx = write_tidx.clone() + local.num_write_count * prod_eval_span; + let read_claim_mult = active_non_dummy.clone() * local.num_read_count; + let write_claim_mult = active_non_dummy.clone() * local.num_write_count; + let logup_claim_mult = active_non_dummy.clone() * local.num_logup_count; + let lambda_one = { + let mut arr = core::array::from_fn(|_| AB::Expr::ZERO); + arr[0] = AB::Expr::ONE; + arr + }; self.prod_read_claim_input_bus.send( builder, local.proof_idx, - TowerProdLayerChallengeMessage { - idx: local.idx.into(), + TowerProdLayerInputMessage { + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - tidx: tidx_for_claims.clone(), - lambda: local.lambda.map(Into::into), - lambda_prime: local.lambda_prime.map(Into::into), + tidx: read_tidx, + lambda_next: local.lambda.map(Into::into), + lambda_cur: local.lambda_prime.map(Into::into), mu: local.mu.map(Into::into), + prod_offset: AB::Expr::ZERO, + lambda_next_start: lambda_one.clone(), + lambda_cur_start: lambda_one.clone(), + num_prod_count: local.num_read_count.into(), }, - is_not_dummy.clone(), + read_claim_mult.clone(), ); - // TODO separate lambda, lambda_prime for prod-write the relation should be local.lambda^(num_read) self.prod_write_claim_input_bus.send( builder, local.proof_idx, - TowerProdLayerChallengeMessage { - idx: local.idx.into(), + TowerProdLayerInputMessage { + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - tidx: tidx_for_claims.clone(), - lambda: local.lambda.map(Into::into), - lambda_prime: local.lambda_prime.map(Into::into), + tidx: write_tidx, + lambda_next: local.lambda.map(Into::into), + lambda_cur: local.lambda_prime.map(Into::into), mu: local.mu.map(Into::into), + prod_offset: local.num_read_count.into(), + lambda_next_start: local.read_lambda_end.map(Into::into), + lambda_cur_start: local.read_lambda_prime_end.map(Into::into), + num_prod_count: local.num_write_count.into(), }, - is_not_dummy.clone(), + write_claim_mult.clone(), ); - // TODO separate lambda, lambda_prime for logup the relation should be local.lambda^(num_read + num_write) self.logup_claim_input_bus.send( builder, local.proof_idx, TowerLogupLayerChallengeMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - tidx: tidx_for_claims.clone(), - lambda: local.lambda.map(Into::into), - lambda_prime: local.lambda_prime.map(Into::into), + tidx: logup_tidx, + lambda_next: local.lambda.map(Into::into), + lambda_cur: local.lambda_prime.map(Into::into), mu: local.mu.map(Into::into), + lambda_next_start: local.write_lambda_end.map(Into::into), + lambda_cur_start: local.write_lambda_prime_end.map(Into::into), + num_logup_count: local.num_logup_count.into(), }, - is_not_dummy.clone(), + logup_claim_mult.clone(), ); self.prod_read_claim_bus.receive( builder, local.proof_idx, TowerProdSumClaimMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - lambda_claim: local.read_claim.map(Into::into), - lambda_prime_claim: local.read_claim_prime.map(Into::into), - num_prod_count: local.num_read_count.into(), + lambda_next_claim: local.read_claim.map(Into::into), + lambda_cur_claim: local.read_claim_prime.map(Into::into), + lambda_next_end: local.read_lambda_end.map(Into::into), + lambda_cur_end: local.read_lambda_prime_end.map(Into::into), }, - is_not_dummy.clone(), + read_claim_mult, ); self.prod_write_claim_bus.receive( builder, local.proof_idx, TowerProdSumClaimMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - lambda_claim: local.write_claim.map(Into::into), - lambda_prime_claim: local.write_claim_prime.map(Into::into), - num_prod_count: local.num_write_count.into(), + lambda_next_claim: local.write_claim.map(Into::into), + lambda_cur_claim: local.write_claim_prime.map(Into::into), + lambda_next_end: local.write_lambda_end.map(Into::into), + lambda_cur_end: local.write_lambda_prime_end.map(Into::into), }, - is_not_dummy.clone(), + write_claim_mult, ); self.logup_claim_bus.receive( builder, local.proof_idx, TowerLogupClaimMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - lambda_claim: local.logup_claim.map(Into::into), - lambda_prime_claim: local.logup_claim_prime.map(Into::into), - num_logup_count: local.num_logup_count.into(), + lambda_next_claim: local.logup_claim.map(Into::into), + lambda_cur_claim: local.logup_claim_prime.map(Into::into), }, - is_not_dummy.clone(), - ); - - let root_layer_mask = local.is_first * is_not_dummy.clone(); - assert_array_eq( - &mut builder.when(root_layer_mask.clone()), - local.read_claim_prime, - local.r0_claim, - ); - assert_array_eq( - &mut builder.when(root_layer_mask.clone()), - local.write_claim_prime, - local.w0_claim, - ); - assert_array_eq( - &mut builder.when(root_layer_mask), - local.logup_claim_prime, - local.q0_claim, + logup_claim_mult, ); // 1. TowerLayerInputBus @@ -359,13 +367,15 @@ where builder, local.proof_idx, TowerLayerInputMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), tidx: local.tidx.into(), - r0_claim: local.r0_claim.map(Into::into), - w0_claim: local.w0_claim.map(Into::into), - q0_claim: local.q0_claim.map(Into::into), + num_layers: local.num_layers.into(), + num_read_specs: local.num_read_count.into(), + num_write_specs: local.num_write_count.into(), + num_logup_specs: local.num_logup_count.into(), + initial_tower_claim: local.initial_tower_claim.map(Into::into), }, - local.is_first_air_idx * is_not_dummy.clone(), + local.is_first * active_non_dummy.clone(), ); // 2. TowerLayerOutputBus // 2a. Send GKR input layer claims back @@ -373,14 +383,14 @@ where builder, local.proof_idx, TowerLayerOutputMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), tidx: tidx_end, layer_idx_end: local.layer_idx.into(), input_layer_claim: folded_claim.map(Into::into), - lambda: local.lambda.map(Into::into), + lambda_next: local.lambda.map(Into::into), mu: local.mu.map(Into::into), }, - is_last.clone() * is_not_dummy.clone(), + is_last.clone() * active_non_dummy.clone(), ); // 3. TowerSumcheckInputBus // 3a. Send claim to sumcheck @@ -389,10 +399,10 @@ where builder, local.proof_idx, TowerSumcheckInputMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), is_last_layer: is_last.clone(), - tidx: local.tidx + AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN), + tidx: local.tidx + AB::Expr::from_usize(SUMCHECK_INIT_LEN), claim: local.sumcheck_claim_in.map(Into::into), }, is_non_root_layer.clone() * is_not_dummy.clone(), @@ -408,7 +418,7 @@ where builder, local.proof_idx, TowerSumcheckOutputMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), tidx: tidx_after_sumcheck.clone(), claim_out: sumcheck_claim_out.map(Into::into), @@ -422,12 +432,12 @@ where builder, local.proof_idx, TowerSumcheckChallengeMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - sumcheck_round: AB::Expr::ZERO, + sumcheck_round: local.layer_idx.into(), challenge: local.mu.map(Into::into), }, - is_transition.clone() * is_not_dummy.clone(), + is_transition.clone() * active_non_dummy, ); /////////////////////////////////////////////////////////////////////// @@ -435,27 +445,115 @@ where /////////////////////////////////////////////////////////////////////// // 1. TranscriptBus - // sample lambda and mu - // in root & intermediate layer: for next.sumcheck_claim_in - // in last layer: for send back to GKR input layer - // 1a. Sample `lambda` — only on non-root layers. - // Root layer uses alpha_logup (set in trace), not a transcript sample. + let root_lambda_label_tidx = local.tidx + out_eval_span.clone(); + for (i, value) in LABEL_COMBINE_VALUES.iter().enumerate() { + self.transcript_bus.observe( + builder, + local.proof_idx, + root_lambda_label_tidx.clone() + AB::Expr::from_usize(i), + AB::Expr::from_usize(*value), + local.is_enabled * local.is_first * is_not_dummy.clone(), + ); + } + let root_lambda_tidx = root_lambda_label_tidx.clone() + AB::Expr::from_usize(LABEL_COMBINE); + + let non_root_init_tidx = local.tidx; + self.transcript_bus.observe( + builder, + local.proof_idx, + non_root_init_tidx, + local.layer_idx, + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + self.transcript_bus.observe( + builder, + local.proof_idx, + local.tidx + AB::Expr::ONE, + AB::Expr::ZERO, + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + self.transcript_bus.observe( + builder, + local.proof_idx, + local.tidx + AB::Expr::from_usize(2), + AB::Expr::from_usize(3), + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + self.transcript_bus.observe( + builder, + local.proof_idx, + local.tidx + AB::Expr::from_usize(3), + AB::Expr::ZERO, + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + + let root_mu_label_tidx = root_lambda_tidx.clone() + AB::Expr::from_usize(D_EF); + for (i, value) in LABEL_PRODUCT_SUM_VALUES.iter().enumerate() { + self.transcript_bus.observe( + builder, + local.proof_idx, + root_mu_label_tidx.clone() + AB::Expr::from_usize(i), + AB::Expr::from_usize(*value), + local.is_enabled * local.is_first * is_not_dummy.clone(), + ); + } + let root_mu_tidx = root_mu_label_tidx + AB::Expr::from_usize(LABEL_PRODUCT_SUM); + + let non_root_mu_label_tidx = tidx_after_sumcheck.clone() + out_eval_span; + for (i, value) in LABEL_MERGE_VALUES.iter().enumerate() { + self.transcript_bus.observe( + builder, + local.proof_idx, + non_root_mu_label_tidx.clone() + AB::Expr::from_usize(i), + AB::Expr::from_usize(*value), + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + } + let non_root_mu_tidx = non_root_mu_label_tidx + AB::Expr::from_usize(LABEL_MERGE); + + let non_root_lambda_label_tidx = non_root_mu_tidx.clone() + AB::Expr::from_usize(D_EF); + for (i, value) in LABEL_COMBINE_VALUES.iter().enumerate() { + self.transcript_bus.observe( + builder, + local.proof_idx, + non_root_lambda_label_tidx.clone() + AB::Expr::from_usize(i), + AB::Expr::from_usize(*value), + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + } + let non_root_lambda_tidx = non_root_lambda_label_tidx + AB::Expr::from_usize(LABEL_COMBINE); + + // 1a. Sample `lambda`: root lambda_1 after root out-evals, later + // rows sample lambda_next after merge. self.transcript_bus.sample_ext( builder, local.proof_idx, - local.tidx, + root_lambda_tidx, + local.lambda, + local.is_enabled * local.is_first * is_not_dummy.clone(), + ); + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + non_root_lambda_tidx, local.lambda, is_non_root_layer.clone() * is_not_dummy.clone(), ); - // 1b. Observe layer claims - let tidx = tidx_after_sumcheck; - // 1c. Sample `mu` + // 1b. Sample `mu`: root r_1 after product_sum; later rows after + // child-claim observations and merge label. + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + root_mu_tidx, + local.mu, + local.is_enabled * local.is_first * is_not_dummy.clone(), + ); self.transcript_bus.sample_ext( builder, local.proof_idx, - tidx, + non_root_mu_tidx, local.mu, - local.is_enabled * is_not_dummy.clone(), + is_non_root_layer * is_not_dummy, ); } } diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs index 0c0fef221..32d0bf80c 100644 --- a/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs @@ -11,12 +11,13 @@ use p3_matrix::Matrix; use stark_recursion_circuit_derive::AlignedBorrow; use crate::tower::bus::{ - TowerLogupClaimBus, TowerLogupClaimInputBus, TowerLogupClaimMessage, - TowerLogupLayerChallengeMessage, + TowerLogupClaimBus, TowerLogupClaimInputBus, TowerLogupClaimMessage, TowerLogupInitBus, + TowerLogupInitMessage, TowerLogupLayerChallengeMessage, TowerLogupRootBus, + TowerLogupRootInputBus, TowerLogupRootInputMessage, TowerLogupRootMessage, }; use recursion_circuit::{ bus::TranscriptBus, - utils::{assert_zeros, ext_field_add, ext_field_multiply, ext_field_subtract}, + utils::{assert_one_ext, assert_zeros, ext_field_add, ext_field_multiply, ext_field_subtract}, }; #[repr(C)] @@ -25,9 +26,11 @@ pub struct TowerLogupSumCheckClaimCols { pub is_enabled: T, pub proof_idx: T, pub idx: T, + pub chip_id: T, pub is_first_layer: T, pub is_first: T, pub is_dummy: T, + pub is_root_layer: T, pub layer_idx: T, pub index_id: T, @@ -49,6 +52,8 @@ pub struct TowerLogupSumCheckClaimCols { pub acc_sum: [T; D_EF], pub acc_p_cross: [T; D_EF], pub acc_q_cross: [T; D_EF], + pub root_p_acc: [T; D_EF], + pub root_q_acc: [T; D_EF], pub num_logup_count: T, } @@ -56,6 +61,9 @@ pub struct TowerLogupSumCheckClaimAir { pub transcript_bus: TranscriptBus, pub logup_claim_input_bus: TowerLogupClaimInputBus, pub logup_claim_bus: TowerLogupClaimBus, + pub root_input_bus: TowerLogupRootInputBus, + pub root_bus: TowerLogupRootBus, + pub init_bus: TowerLogupInitBus, } impl BaseAir for TowerLogupSumCheckClaimAir { @@ -83,6 +91,10 @@ where builder.assert_bool(local.is_dummy); builder.assert_bool(local.is_first_layer); + builder.assert_bool(local.is_root_layer); + builder + .when(local.is_root_layer) + .assert_zero(local.layer_idx); /////////////////////////////////////////////////////////////////////// // Structural constraints (replaces NestedForLoopSubAir<2>) @@ -174,40 +186,64 @@ where .when(local.is_enabled * next.is_enabled * next.is_first_layer) .assert_zero(next.index_id); builder - .when(is_within_layer.clone() * is_not_dummy.clone()) + .when(is_within_layer.clone()) .assert_eq(next.index_id, local.index_id + AB::Expr::ONE); builder - .when(is_layer_end.clone() * is_not_dummy.clone()) + .when(is_layer_end.clone() * local.num_logup_count) .assert_eq(local.index_id + AB::Expr::ONE, local.num_logup_count); assert_zeros( - &mut builder.when(local.is_first * is_not_dummy.clone()), + &mut builder.when(local.is_first), local.acc_sum.map(Into::into), ); assert_zeros( - &mut builder.when(local.is_first * is_not_dummy.clone()), + &mut builder.when(local.is_first), local.acc_p_cross.map(Into::into), ); assert_zeros( - &mut builder.when(local.is_first * is_not_dummy.clone()), + &mut builder.when(local.is_first), local.acc_q_cross.map(Into::into), ); - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_eq(local.pow_lambda[0], AB::Expr::ONE); - for limb in local.pow_lambda.iter().copied().skip(1) { - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_zero(limb); - } - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_eq(local.pow_lambda_prime[0], AB::Expr::ONE); - for limb in local.pow_lambda_prime.iter().copied().skip(1) { - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_zero(limb); - } + assert_zeros( + &mut builder.when(local.is_first * local.is_root_layer), + local.root_p_acc.map(Into::into), + ); + assert_one_ext( + &mut builder.when(local.is_first * local.is_root_layer), + local.root_q_acc, + ); + assert_zeros( + &mut builder.when(local.is_first * (AB::Expr::ONE - local.is_root_layer)), + local.root_p_acc.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_first * (AB::Expr::ONE - local.is_root_layer)), + local.root_q_acc.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.p_xi_0.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.p_xi_1.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.q_xi_0.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.q_xi_1.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.p_xi.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.q_xi.map(Into::into), + ); let delta_p = ext_field_subtract::(local.p_xi_1, local.p_xi_0); let expected_p_xi = @@ -221,6 +257,12 @@ where let (p_cross_term, q_cross_term) = compute_recursive_relations(local.p_xi_0, local.q_xi_0, local.p_xi_1, local.q_xi_1); + let root_p_with_cur = ext_field_add::( + ext_field_multiply::(local.root_p_acc, q_cross_term.clone()), + ext_field_multiply::(p_cross_term.clone(), local.root_q_acc), + ); + let root_q_with_cur = + ext_field_multiply::(local.root_q_acc, q_cross_term.clone()); let lambda = local.lambda.map(Into::into); let pow_lambda = local.pow_lambda.map(Into::into); @@ -237,7 +279,8 @@ where next.acc_sum, acc_sum_with_cur, ); - let pow_lambda_next = ext_field_multiply::(pow_lambda, lambda.clone()); + let lambda_sq = ext_field_multiply::(lambda.clone(), lambda.clone()); + let pow_lambda_next = ext_field_multiply::(pow_lambda, lambda_sq); assert_array_eq( &mut builder.when(is_within_layer.clone()), next.pow_lambda, @@ -265,8 +308,20 @@ where next.acc_q_cross, acc_q_with_cur.clone(), ); + assert_array_eq( + &mut builder.when(is_within_layer.clone()), + next.root_p_acc, + root_p_with_cur.clone(), + ); + assert_array_eq( + &mut builder.when(is_within_layer.clone()), + next.root_q_acc, + root_q_with_cur.clone(), + ); + let lambda_prime_sq = + ext_field_multiply::(lambda_prime.clone(), lambda_prime.clone()); let pow_lambda_prime_next = - ext_field_multiply::(pow_lambda_prime, lambda_prime.clone()); + ext_field_multiply::(pow_lambda_prime, lambda_prime_sq); assert_array_eq( &mut builder.when(is_within_layer.clone()), next.pow_lambda_prime, @@ -277,31 +332,67 @@ where builder, local.proof_idx, TowerLogupLayerChallengeMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), tidx: local.tidx.into(), - lambda: lambda.clone(), - lambda_prime: lambda_prime.clone(), + lambda_next: lambda.clone(), + lambda_cur: lambda_prime.clone(), mu: local.mu.map(Into::into), + lambda_next_start: local.pow_lambda.map(Into::into), + lambda_cur_start: local.pow_lambda_prime.map(Into::into), + num_logup_count: local.num_logup_count.into(), }, - local.is_first.into(), + local.is_first * local.is_enabled * local.num_logup_count, ); self.logup_claim_bus.send( builder, local.proof_idx, TowerLogupClaimMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - lambda_claim: acc_sum_export.map(Into::into), - lambda_prime_claim: acc_q_with_cur.map(Into::into), + lambda_next_claim: acc_sum_export.clone().map(Into::into), + lambda_cur_claim: ext_field_add::(acc_p_with_cur, acc_q_with_cur) + .map(Into::into), + }, + is_layer_end.clone() * local.num_logup_count, + ); + + self.root_input_bus.receive( + builder, + local.proof_idx, + TowerLogupRootInputMessage { + chip_id: local.chip_id.into(), + tidx: local.tidx.into(), + lambda_1: lambda, + r_1: local.mu.map(Into::into), + lambda_1_start: local.pow_lambda.map(Into::into), num_logup_count: local.num_logup_count.into(), }, - is_layer_end, + local.is_first * local.is_enabled * local.num_logup_count * local.is_root_layer, + ); + self.root_bus.send( + builder, + local.proof_idx, + TowerLogupRootMessage { + chip_id: local.chip_id.into(), + p0_claim: root_p_with_cur.map(Into::into), + q0_claim: root_q_with_cur.map(Into::into), + }, + is_layer_end.clone() * local.num_logup_count * local.is_root_layer, + ); + self.init_bus.send( + builder, + local.proof_idx, + TowerLogupInitMessage { + chip_id: local.chip_id.into(), + initial_claim: acc_sum_export.map(Into::into), + }, + is_layer_end * local.num_logup_count * local.is_root_layer, ); let mut tidx = local.tidx.into(); - for claim in [local.p_xi_0, local.q_xi_0, local.p_xi_1, local.q_xi_1] { + for claim in [local.p_xi_0, local.p_xi_1, local.q_xi_0, local.q_xi_1] { self.transcript_bus.observe_ext( builder, local.proof_idx, diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs index b1b0d92f0..58957d50a 100644 --- a/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs @@ -7,7 +7,10 @@ use p3_matrix::dense::RowMajorMatrix; use super::TowerLogupSumCheckClaimCols; use crate::{ - tower::{TowerTowerEvalRecord, interpolate_pair, layer::trace::TowerLayerRecord}, + tower::{ + TowerTowerEvalRecord, interpolate_pair, + layer::trace::{TowerLayerRecord, ext_pow}, + }, tracegen::RowMajorChip, }; @@ -20,13 +23,10 @@ type LogupTraceCtx<'a> = ( ); fn logup_rows_for_record(record: &TowerLayerRecord) -> usize { - if record.layer_count() == 0 { - 1 - } else { - (0..record.layer_count()) - .map(|layer_idx| record.logup_count_at(layer_idx).max(1)) - .sum() - } + (0..record.layer_count()) + .map(|layer_idx| record.logup_count_at(layer_idx)) + .sum::() + .max(1) } impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { @@ -69,7 +69,15 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .zip(mus_records.par_iter()), ) .for_each(|(chunk, ((record, tower), mus_for_proof))| { - if record.layer_count() == 0 { + if chunk.is_empty() { + return; + } + + let active_row_count = (0..record.layer_count()) + .map(|layer_idx| record.logup_count_at(layer_idx)) + .sum::(); + + if active_row_count == 0 { debug_assert_eq!(chunk.len(), width); let row_data = &mut chunk[..width]; let cols: &mut TowerLogupSumCheckClaimCols = row_data.borrow_mut(); @@ -77,8 +85,10 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.is_first_layer = F::from_bool(record.is_first_air_idx); cols.is_first = F::ONE; // single row = first of its (degenerate) layer cols.is_dummy = F::ONE; + cols.is_root_layer = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.layer_idx = F::ZERO; cols.index_id = F::ZERO; cols.tidx = F::from_usize(record.tidx); @@ -98,7 +108,9 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.acc_sum = [F::ZERO; D_EF]; cols.acc_p_cross = [F::ZERO; D_EF]; cols.acc_q_cross = [F::ZERO; D_EF]; - cols.num_logup_count = F::ONE; + cols.root_p_acc = [F::ZERO; D_EF]; + cols.root_q_acc = lambda_prime_one; + cols.num_logup_count = F::ZERO; return; } @@ -111,9 +123,14 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .get(layer_idx) .map(|rows| rows.as_slice()) .unwrap_or(&[]); - let total_rows = record.logup_count_at(layer_idx).max(1); + let logup_active = tower + .logup_active + .get(layer_idx) + .map(|rows| rows.as_slice()) + .unwrap_or(&[]); + let total_rows = record.logup_count_at(layer_idx); debug_assert!( - total_rows == logup_rows.len().max(1), + total_rows == logup_rows.len(), "unexpected logup count mismatch at layer {layer_idx}" ); @@ -127,20 +144,24 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .try_into() .unwrap(); let mu_basis: [F; D_EF] = mu.as_basis_coefficients_slice().try_into().unwrap(); - let layer_tidx = record.claim_tidx(layer_idx); + let group_offset = + record.read_count_at(layer_idx) + record.write_count_at(layer_idx); + let layer_tidx = record.claim_tidx(layer_idx) + group_offset * 2 * D_EF; - let mut pow_lambda = EF::ONE; - let mut pow_lambda_prime = EF::ONE; + let mut pow_lambda = ext_pow(lambda, group_offset); + let mut pow_lambda_prime = ext_pow(lambda_prime, group_offset); let mut acc_sum = EF::ZERO; let mut acc_p_cross = EF::ZERO; let mut acc_q_cross = EF::ZERO; + let mut root_p_acc = EF::ZERO; + let mut root_q_acc = if layer_idx == 0 { EF::ONE } else { EF::ZERO }; for row_in_layer in 0..total_rows { let row = chunk_iter .next() .expect("chunk should have enough rows for layer"); let cols: &mut TowerLogupSumCheckClaimCols = row.borrow_mut(); - let is_real = row_in_layer < logup_rows.len(); + let is_real = logup_active.get(row_in_layer).copied().unwrap_or(false); let quad = if is_real { logup_rows[row_in_layer] } else { @@ -176,6 +197,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.is_enabled = F::ONE; cols.is_dummy = F::from_bool(!is_real); + cols.is_root_layer = F::from_bool(layer_idx == 0); let is_first_row_of_layer = row_in_layer == 0; let is_first_row_of_record = proof_row_idx == 0; cols.is_first_layer = @@ -183,6 +205,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.is_first = F::from_bool(is_first_row_of_layer); cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.layer_idx = F::from_usize(layer_idx); cols.index_id = F::from_usize(row_in_layer); cols.tidx = F::from_usize(layer_tidx + row_in_layer * 4 * D_EF); @@ -210,13 +233,23 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .as_basis_coefficients_slice() .try_into() .unwrap(); + cols.root_p_acc = + root_p_acc.as_basis_coefficients_slice().try_into().unwrap(); + cols.root_q_acc = + root_q_acc.as_basis_coefficients_slice().try_into().unwrap(); cols.num_logup_count = F::from_usize(total_rows); acc_sum += contribution; acc_p_cross += p_cross_contribution; acc_q_cross += q_cross_contribution; - pow_lambda *= lambda; - pow_lambda_prime *= lambda_prime; + if is_real { + let next_root_p = root_p_acc * q_cross + p_cross * root_q_acc; + let next_root_q = root_q_acc * q_cross; + root_p_acc = next_root_p; + root_q_acc = next_root_q; + } + pow_lambda *= lambda * lambda; + pow_lambda_prime *= lambda_prime * lambda_prime; proof_row_idx += 1; } diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs index 27ad5a1de..474695bed 100644 --- a/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs @@ -11,12 +11,15 @@ use p3_matrix::Matrix; use stark_recursion_circuit_derive::AlignedBorrow; use crate::tower::bus::{ - TowerProdLayerChallengeMessage, TowerProdReadClaimBus, TowerProdReadClaimInputBus, + TowerProdInitMessage, TowerProdLayerInputMessage, TowerProdReadClaimBus, + TowerProdReadClaimInputBus, TowerProdRootInputMessage, TowerProdRootMessage, TowerProdSumClaimMessage, TowerProdWriteClaimBus, TowerProdWriteClaimInputBus, + TowerReadInitBus, TowerReadRootBus, TowerReadRootInputBus, TowerWriteInitBus, + TowerWriteRootBus, TowerWriteRootInputBus, }; use recursion_circuit::{ bus::TranscriptBus, - utils::{assert_zeros, ext_field_add, ext_field_multiply, ext_field_subtract}, + utils::{assert_one_ext, assert_zeros, ext_field_add, ext_field_multiply, ext_field_subtract}, }; #[repr(C)] @@ -25,12 +28,15 @@ pub struct TowerProdSumCheckClaimCols { pub is_enabled: T, pub proof_idx: T, pub idx: T, + pub chip_id: T, pub is_first_layer: T, pub is_first: T, pub is_dummy: T, + pub is_root_layer: T, pub layer_idx: T, pub index_id: T, + pub prod_offset: T, pub tidx: T, pub lambda: [T; D_EF], @@ -43,43 +49,68 @@ pub struct TowerProdSumCheckClaimCols { pub pow_lambda_prime: [T; D_EF], pub acc_sum: [T; D_EF], pub acc_sum_prime: [T; D_EF], + pub root_output_acc: [T; D_EF], pub num_prod_count: T, } -pub struct TowerProdSumCheckClaimAir { +pub struct TowerProdSumCheckClaimAir { pub transcript_bus: TranscriptBus, pub prod_claim_input_bus: IB, pub prod_claim_bus: OB, + pub root_input_bus: RIB, + pub root_bus: ROB, + pub init_bus: INITB, } -pub type TowerProdReadSumCheckClaimAir = - TowerProdSumCheckClaimAir; -pub type TowerProdWriteSumCheckClaimAir = - TowerProdSumCheckClaimAir; - -impl BaseAir for TowerProdSumCheckClaimAir { +pub type TowerProdReadSumCheckClaimAir = TowerProdSumCheckClaimAir< + TowerProdReadClaimInputBus, + TowerProdReadClaimBus, + TowerReadRootInputBus, + TowerReadRootBus, + TowerReadInitBus, +>; +pub type TowerProdWriteSumCheckClaimAir = TowerProdSumCheckClaimAir< + TowerProdWriteClaimInputBus, + TowerProdWriteClaimBus, + TowerWriteRootInputBus, + TowerWriteRootBus, + TowerWriteInitBus, +>; + +impl BaseAir + for TowerProdSumCheckClaimAir +{ fn width(&self) -> usize { TowerProdSumCheckClaimCols::::width() } } -impl BaseAirWithPublicValues - for TowerProdSumCheckClaimAir +impl BaseAirWithPublicValues + for TowerProdSumCheckClaimAir +{ +} +impl PartitionedBaseAir + for TowerProdSumCheckClaimAir { } -impl PartitionedBaseAir for TowerProdSumCheckClaimAir {} -impl TowerProdSumCheckClaimAir { - fn eval_core( +impl TowerProdSumCheckClaimAir { + fn eval_core( &self, builder: &mut AB, - mut recv_challenge: Recv, - mut send_claim: Send, + mut recv_challenge: RecvLayer, + mut send_claim: SendLayer, + mut recv_root: RecvRoot, + mut send_root: SendRoot, + mut send_init: SendInit, ) where AB: AirBuilder + InteractionBuilder, ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, - Recv: FnMut(&IB, &mut AB, AB::Var, TowerProdLayerChallengeMessage, AB::Expr), - Send: FnMut(&OB, &mut AB, AB::Var, TowerProdSumClaimMessage, AB::Expr), + RecvLayer: FnMut(&IB, &mut AB, AB::Var, TowerProdLayerInputMessage, AB::Expr), + SendLayer: FnMut(&OB, &mut AB, AB::Var, TowerProdSumClaimMessage, AB::Expr), + RecvRoot: FnMut(&RIB, &mut AB, AB::Var, TowerProdRootInputMessage, AB::Expr), + SendRoot: FnMut(&ROB, &mut AB, AB::Var, TowerProdRootMessage, AB::Expr), + SendInit: FnMut(&INITB, &mut AB, AB::Var, TowerProdInitMessage, AB::Expr), { let main = builder.main(); let (local_row, next_row) = ( @@ -91,6 +122,10 @@ impl TowerProdSumCheckClaimAir { builder.assert_bool(local.is_dummy); builder.assert_bool(local.is_first_layer); + builder.assert_bool(local.is_root_layer); + builder + .when(local.is_root_layer) + .assert_zero(local.layer_idx); /////////////////////////////////////////////////////////////////////// // Structural constraints (replaces NestedForLoopSubAir<2>) @@ -193,36 +228,40 @@ impl TowerProdSumCheckClaimAir { .when(local.is_enabled * next.is_enabled * next.is_first_layer) .assert_zero(next.index_id); builder - .when(is_within_layer.clone() * is_not_dummy.clone()) + .when(is_within_layer.clone()) .assert_eq(next.index_id, local.index_id + AB::Expr::ONE); builder - .when(is_layer_end.clone() * is_not_dummy.clone()) + .when(is_layer_end.clone() * local.num_prod_count) .assert_eq(local.index_id + AB::Expr::ONE, local.num_prod_count); assert_zeros( - &mut builder.when(local.is_first * is_not_dummy.clone()), + &mut builder.when(local.is_first), local.acc_sum.map(Into::into), ); assert_zeros( - &mut builder.when(local.is_first * is_not_dummy.clone()), + &mut builder.when(local.is_first), local.acc_sum_prime.map(Into::into), ); - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_eq(local.pow_lambda[0], AB::Expr::ONE); - for limb in local.pow_lambda.iter().copied().skip(1) { - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_zero(limb); - } - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_eq(local.pow_lambda_prime[0], AB::Expr::ONE); - for limb in local.pow_lambda_prime.iter().copied().skip(1) { - builder - .when(local.is_first * is_not_dummy.clone()) - .assert_zero(limb); - } + assert_one_ext( + &mut builder.when(local.is_first * local.is_root_layer), + local.root_output_acc, + ); + assert_zeros( + &mut builder.when(local.is_first * (AB::Expr::ONE - local.is_root_layer)), + local.root_output_acc.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.p_xi_0.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.p_xi_1.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_dummy), + local.p_xi.map(Into::into), + ); let delta = ext_field_subtract::(local.p_xi_1, local.p_xi_0); let expected_p_xi = @@ -235,6 +274,8 @@ impl TowerProdSumCheckClaimAir { let acc_sum_export = acc_sum_with_cur.clone(); let prime_product = ext_field_multiply::(local.p_xi_0, local.p_xi_1); + let root_output_with_cur = + ext_field_multiply::(local.root_output_acc, prime_product.clone()); let pow_lambda_prime = local.pow_lambda_prime.map(Into::into); let prime_contribution = ext_field_multiply::(pow_lambda_prime.clone(), prime_product); @@ -252,9 +293,15 @@ impl TowerProdSumCheckClaimAir { next.acc_sum_prime, acc_sum_prime_with_cur, ); + assert_array_eq( + &mut builder.when(is_within_layer.clone()), + next.root_output_acc, + root_output_with_cur.clone(), + ); let lambda = local.lambda.map(Into::into); let pow_lambda_next = ext_field_multiply::(pow_lambda, lambda.clone()); + let lambda_end = pow_lambda_next.clone(); assert_array_eq( &mut builder.when(is_within_layer.clone()), next.pow_lambda, @@ -263,6 +310,7 @@ impl TowerProdSumCheckClaimAir { let lambda_prime = local.lambda_prime.map(Into::into); let pow_lambda_prime_next = ext_field_multiply::(pow_lambda_prime, lambda_prime.clone()); + let lambda_prime_end = pow_lambda_prime_next.clone(); assert_array_eq( &mut builder.when(is_within_layer.clone()), next.pow_lambda_prime, @@ -273,15 +321,19 @@ impl TowerProdSumCheckClaimAir { &self.prod_claim_input_bus, builder, local.proof_idx, - TowerProdLayerChallengeMessage { - idx: local.idx.into(), + TowerProdLayerInputMessage { + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), tidx: local.tidx.into(), - lambda, - lambda_prime: lambda_prime.clone(), + lambda_next: lambda.clone(), + lambda_cur: lambda_prime.clone(), mu: local.mu.map(Into::into), + prod_offset: local.prod_offset.into(), + lambda_next_start: local.pow_lambda.map(Into::into), + lambda_cur_start: local.pow_lambda_prime.map(Into::into), + num_prod_count: local.num_prod_count.into(), }, - local.is_first.into(), + local.is_first * local.is_enabled * local.num_prod_count, ); send_claim( @@ -289,13 +341,49 @@ impl TowerProdSumCheckClaimAir { builder, local.proof_idx, TowerProdSumClaimMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - lambda_claim: acc_sum_export.map(Into::into), - lambda_prime_claim: acc_sum_prime_export.map(Into::into), + lambda_next_claim: acc_sum_export.clone().map(Into::into), + lambda_cur_claim: acc_sum_prime_export.map(Into::into), + lambda_next_end: lambda_end.map(Into::into), + lambda_cur_end: lambda_prime_end.map(Into::into), + }, + is_layer_end.clone() * local.num_prod_count, + ); + + recv_root( + &self.root_input_bus, + builder, + local.proof_idx, + TowerProdRootInputMessage { + chip_id: local.chip_id.into(), + tidx: local.tidx.into(), + lambda_1: lambda, + r_1: local.mu.map(Into::into), + lambda_1_start: local.pow_lambda.map(Into::into), num_prod_count: local.num_prod_count.into(), }, - is_layer_end, + local.is_first * local.is_enabled * local.num_prod_count * local.is_root_layer, + ); + send_root( + &self.root_bus, + builder, + local.proof_idx, + TowerProdRootMessage { + chip_id: local.chip_id.into(), + output_claim: root_output_with_cur.map(Into::into), + }, + is_layer_end.clone() * local.num_prod_count * local.is_root_layer, + ); + send_init( + &self.init_bus, + builder, + local.proof_idx, + TowerProdInitMessage { + chip_id: local.chip_id.into(), + initial_claim: acc_sum_export.map(Into::into), + }, + is_layer_end * local.num_prod_count * local.is_root_layer, ); let mut tidx = local.tidx.into(); @@ -333,6 +421,15 @@ macro_rules! impl_prod_sum_air { |bus, builder, proof_idx, msg, mult| { bus.send(builder, proof_idx, msg, mult); }, + |bus, builder, proof_idx, msg, mult| { + bus.receive(builder, proof_idx, msg, mult); + }, + |bus, builder, proof_idx, msg, mult| { + bus.send(builder, proof_idx, msg, mult); + }, + |bus, builder, proof_idx, msg, mult| { + bus.send(builder, proof_idx, msg, mult); + }, ); } } diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs index 8253341dd..c03fa5ddf 100644 --- a/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs @@ -7,7 +7,10 @@ use p3_matrix::dense::RowMajorMatrix; use super::TowerProdSumCheckClaimCols; use crate::{ - tower::{TowerTowerEvalRecord, interpolate_pair, layer::trace::TowerLayerRecord}, + tower::{ + TowerTowerEvalRecord, interpolate_pair, + layer::trace::{TowerLayerRecord, ext_pow}, + }, tracegen::RowMajorChip, }; @@ -21,19 +24,16 @@ type ProdTraceCtx<'a> = ( ); fn prod_rows_for_record(record: &TowerLayerRecord, is_write: bool) -> usize { - if record.layer_count() == 0 { - 1 - } else { - (0..record.layer_count()) - .map(|layer_idx| { - if is_write { - record.write_count_at(layer_idx).max(1) - } else { - record.read_count_at(layer_idx).max(1) - } - }) - .sum() - } + (0..record.layer_count()) + .map(|layer_idx| { + if is_write { + record.write_count_at(layer_idx) + } else { + record.read_count_at(layer_idx) + } + }) + .sum::() + .max(1) } #[allow(clippy::too_many_arguments)] @@ -77,7 +77,21 @@ fn generate_prod_trace( .zip(mus_records.par_iter()), ) .for_each(|(chunk, ((record, tower), mus_for_proof))| { - if record.layer_count() == 0 { + if chunk.is_empty() { + return; + } + + let active_row_count = (0..record.layer_count()) + .map(|layer_idx| { + if is_write { + record.write_count_at(layer_idx) + } else { + record.read_count_at(layer_idx) + } + }) + .sum::(); + + if active_row_count == 0 { debug_assert_eq!(chunk.len(), width); let row_data = &mut chunk[..width]; let cols: &mut TowerProdSumCheckClaimCols = row_data.borrow_mut(); @@ -85,10 +99,13 @@ fn generate_prod_trace( cols.is_first_layer = F::from_bool(record.is_first_air_idx); cols.is_first = F::ONE; // single row = first of its (degenerate) layer cols.is_dummy = F::ONE; + cols.is_root_layer = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.layer_idx = F::ZERO; cols.index_id = F::ZERO; + cols.prod_offset = F::ZERO; cols.tidx = F::from_usize(record.tidx); cols.lambda = [F::ZERO; D_EF]; let mut lambda_prime_one = [F::ZERO; D_EF]; @@ -102,7 +119,8 @@ fn generate_prod_trace( cols.pow_lambda_prime = lambda_prime_one; cols.acc_sum = [F::ZERO; D_EF]; cols.acc_sum_prime = [F::ZERO; D_EF]; - cols.num_prod_count = F::ONE; + cols.root_output_acc = lambda_prime_one; + cols.num_prod_count = F::ZERO; return; } @@ -123,13 +141,26 @@ fn generate_prod_trace( .map(|rows| rows.as_slice()) .unwrap_or(&[]) }; + let active_mask = if is_write { + tower + .write_active + .get(layer_idx) + .map(|rows| rows.as_slice()) + .unwrap_or(&[]) + } else { + tower + .read_active + .get(layer_idx) + .map(|rows| rows.as_slice()) + .unwrap_or(&[]) + }; let total_rows = if is_write { - record.write_count_at(layer_idx).max(1) + record.write_count_at(layer_idx) } else { - record.read_count_at(layer_idx).max(1) + record.read_count_at(layer_idx) }; debug_assert!( - total_rows == active_rows.len().max(1), + total_rows == active_rows.len(), "unexpected prod count mismatch at layer {layer_idx}" ); let lambda = record.lambda_at(layer_idx); @@ -142,19 +173,25 @@ fn generate_prod_trace( .try_into() .unwrap(); let mu_basis: [F; D_EF] = mu.as_basis_coefficients_slice().try_into().unwrap(); - let layer_tidx = record.claim_tidx(layer_idx); + let group_offset = if is_write { + record.read_count_at(layer_idx) + } else { + 0 + }; + let layer_tidx = record.claim_tidx(layer_idx) + group_offset * 2 * D_EF; - let mut pow_lambda = EF::ONE; - let mut pow_lambda_prime = EF::ONE; + let mut pow_lambda = ext_pow(lambda, group_offset); + let mut pow_lambda_prime = ext_pow(lambda_prime, group_offset); let mut acc_sum = EF::ZERO; let mut acc_sum_prime = EF::ZERO; + let mut root_output_acc = if layer_idx == 0 { EF::ONE } else { EF::ZERO }; for row_in_layer in 0..total_rows { let row = chunk_iter .next() .expect("chunk should have enough rows for layer"); let cols: &mut TowerProdSumCheckClaimCols = row.borrow_mut(); - let is_real = row_in_layer < active_rows.len(); + let is_real = active_mask.get(row_in_layer).copied().unwrap_or(false); let pair = if is_real { active_rows[row_in_layer] } else { @@ -173,6 +210,7 @@ fn generate_prod_trace( cols.is_enabled = F::ONE; cols.is_dummy = F::from_bool(!is_real); + cols.is_root_layer = F::from_bool(layer_idx == 0); let is_first_row_of_layer = row_in_layer == 0; let is_first_row_of_record = proof_row_idx == 0; cols.is_first_layer = @@ -180,8 +218,10 @@ fn generate_prod_trace( cols.is_first = F::from_bool(is_first_row_of_layer); cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.layer_idx = F::from_usize(layer_idx); cols.index_id = F::from_usize(row_in_layer); + cols.prod_offset = F::from_usize(group_offset); cols.tidx = F::from_usize(layer_tidx + row_in_layer * 2 * D_EF); cols.lambda = lambda_basis; cols.lambda_prime = lambda_prime_basis; @@ -199,10 +239,17 @@ fn generate_prod_trace( .as_basis_coefficients_slice() .try_into() .unwrap(); + cols.root_output_acc = root_output_acc + .as_basis_coefficients_slice() + .try_into() + .unwrap(); cols.num_prod_count = F::from_usize(total_rows); acc_sum += contribution; acc_sum_prime += prime_contribution; + if is_real { + root_output_acc *= prime_product; + } pow_lambda *= lambda; pow_lambda_prime *= lambda_prime; diff --git a/ceno_recursion_v2/src/tower/layer/trace.rs b/ceno_recursion_v2/src/tower/layer/trace.rs index 1da4f96fe..192d9364e 100644 --- a/ceno_recursion_v2/src/tower/layer/trace.rs +++ b/ceno_recursion_v2/src/tower/layer/trace.rs @@ -13,8 +13,10 @@ use crate::{tower::tower_transcript_len, tracegen::RowMajorChip}; pub struct TowerLayerRecord { pub proof_idx: usize, pub idx: usize, + pub chip_id: usize, pub is_first_air_idx: bool, pub tidx: usize, + pub initial_tower_claim: EF, pub layer_claims: Vec<[EF; 4]>, pub lambdas: Vec, pub eq_at_r_primes: Vec, @@ -111,35 +113,93 @@ impl TowerLayerRecord { #[inline] pub(crate) fn layer_tidx(&self, layer_idx: usize) -> usize { - self.tidx + tower_transcript_len::layers_cumulative(layer_idx) + let mut tidx = self.tidx; + for idx in 0..layer_idx { + tidx += self.layer_span(idx); + } + tidx } #[inline] pub(crate) fn read_count_at(&self, layer_idx: usize) -> usize { - self.read_counts.get(layer_idx).copied().unwrap_or(1) + self.read_counts.get(layer_idx).copied().unwrap_or(0) } #[inline] pub(crate) fn write_count_at(&self, layer_idx: usize) -> usize { - self.write_counts.get(layer_idx).copied().unwrap_or(1) + self.write_counts.get(layer_idx).copied().unwrap_or(0) } #[inline] pub(crate) fn logup_count_at(&self, layer_idx: usize) -> usize { - self.logup_counts.get(layer_idx).copied().unwrap_or(1) + self.logup_counts.get(layer_idx).copied().unwrap_or(0) } #[inline] pub(crate) fn claim_tidx(&self, layer_idx: usize) -> usize { - self.layer_tidx(layer_idx) + tower_transcript_len::claim_offset_in_layer(layer_idx) + self.layer_tidx(layer_idx) + self.claim_offset_in_layer(layer_idx) + } + + #[inline] + pub(crate) fn out_eval_span(&self, layer_idx: usize) -> usize { + let words = 2 * self.read_count_at(layer_idx) + + 2 * self.write_count_at(layer_idx) + + 4 * self.logup_count_at(layer_idx); + words * D_EF + } + + #[inline] + pub(crate) fn claim_offset_in_layer(&self, layer_idx: usize) -> usize { + if layer_idx == 0 { + 0 + } else { + tower_transcript_len::SUMCHECK_INIT_LEN + layer_idx * tower_transcript_len::ROUND_LEN + } + } + + #[inline] + pub(crate) fn lambda_tidx(&self, layer_idx: usize) -> usize { + if layer_idx == 0 { + self.layer_tidx(0) + self.out_eval_span(0) + tower_transcript_len::LABEL_COMBINE + } else { + self.mu_tidx(layer_idx) + D_EF + tower_transcript_len::LABEL_COMBINE + } + } + + #[inline] + pub(crate) fn mu_tidx(&self, layer_idx: usize) -> usize { + if layer_idx == 0 { + self.lambda_tidx(0) + D_EF + tower_transcript_len::LABEL_PRODUCT_SUM + } else { + self.claim_tidx(layer_idx) + + self.out_eval_span(layer_idx) + + tower_transcript_len::LABEL_MERGE + } + } + + #[inline] + pub(crate) fn layer_span(&self, layer_idx: usize) -> usize { + if layer_idx == 0 { + self.out_eval_span(0) + tower_transcript_len::ALPHA_BETA_LEN + } else { + tower_transcript_len::SUMCHECK_INIT_LEN + + layer_idx * tower_transcript_len::ROUND_LEN + + self.out_eval_span(layer_idx) + + tower_transcript_len::MERGE_LEN + + tower_transcript_len::ALPHA_LEN + } } } +#[inline] +pub(crate) fn ext_pow(base: EF, exp: usize) -> EF { + (0..exp).fold(EF::ONE, |acc, _| acc * base) +} + pub struct TowerLayerTraceGenerator; impl RowMajorChip for TowerLayerTraceGenerator { - // (gkr_layer_records, mus, q0_claims) - type Ctx<'a> = (&'a [TowerLayerRecord], &'a [Vec], &'a [EF]); + type Ctx<'a> = (&'a [TowerLayerRecord], &'a [Vec]); #[tracing::instrument(level = "trace", skip_all)] fn generate_trace( @@ -147,9 +207,8 @@ impl RowMajorChip for TowerLayerTraceGenerator { ctx: &Self::Ctx<'_>, required_height: Option, ) -> Option> { - let (gkr_layer_records, mus, q0_claims) = ctx; + let (gkr_layer_records, mus) = ctx; debug_assert_eq!(gkr_layer_records.len(), mus.len()); - debug_assert_eq!(gkr_layer_records.len(), q0_claims.len()); let width = TowerLayerCols::::width(); let rows_per_proof: Vec = gkr_layer_records @@ -180,15 +239,14 @@ impl RowMajorChip for TowerLayerTraceGenerator { trace_slices .par_iter_mut() - .zip( - gkr_layer_records - .par_iter() - .zip(mus.par_iter()) - .zip(q0_claims.par_iter()), - ) - .for_each(|(chunk, ((record, mus_for_proof), q0_claim))| { - let q0_basis = q0_claim.as_basis_coefficients_slice(); + .zip(gkr_layer_records.par_iter().zip(mus.par_iter())) + .for_each(|(chunk, (record, mus_for_proof))| { let mus_for_proof = mus_for_proof.as_slice(); + let initial_tower_claim: [F; D_EF] = record + .initial_tower_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); if record.layer_claims.is_empty() { debug_assert_eq!(chunk.len(), width); @@ -197,6 +255,7 @@ impl RowMajorChip for TowerLayerTraceGenerator { cols.is_enabled = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.is_first_air_idx = F::from_bool(record.is_first_air_idx); cols.is_first = F::ONE; cols.is_dummy = F::ONE; @@ -214,13 +273,19 @@ impl RowMajorChip for TowerLayerTraceGenerator { cols.write_claim_prime = [F::ZERO; D_EF]; cols.logup_claim = [F::ZERO; D_EF]; cols.logup_claim_prime = [F::ZERO; D_EF]; + cols.read_eval_claim = [F::ZERO; D_EF]; + cols.write_eval_claim = [F::ZERO; D_EF]; + cols.logup_eval_claim = [F::ZERO; D_EF]; + cols.read_lambda_end = lambda_prime_one; + cols.read_lambda_prime_end = lambda_prime_one; + cols.write_lambda_end = lambda_prime_one; + cols.write_lambda_prime_end = lambda_prime_one; cols.num_read_count = F::ZERO; cols.num_write_count = F::ZERO; cols.num_logup_count = F::ZERO; + cols.num_layers = F::ZERO; cols.eq_at_r_prime = [F::ZERO; D_EF]; - cols.r0_claim.copy_from_slice(q0_basis); - cols.w0_claim.copy_from_slice(q0_basis); - cols.q0_claim.copy_from_slice(q0_basis); + cols.initial_tower_claim = initial_tower_claim; return; } @@ -235,6 +300,7 @@ impl RowMajorChip for TowerLayerTraceGenerator { cols.is_dummy = F::ZERO; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.is_first_air_idx = F::from_bool(layer_idx == 0 && record.is_first_air_idx); cols.is_first = F::from_bool(layer_idx == 0); cols.layer_idx = F::from_usize(layer_idx); @@ -272,33 +338,85 @@ impl RowMajorChip for TowerLayerTraceGenerator { .as_basis_coefficients_slice() .try_into() .unwrap(); - cols.num_read_count = F::from_usize(record.read_count_at(layer_idx).max(1)); - cols.num_write_count = F::from_usize(record.write_count_at(layer_idx).max(1)); - cols.num_logup_count = F::from_usize(record.logup_count_at(layer_idx).max(1)); + let read_count = record.read_count_at(layer_idx); + let write_count = record.write_count_at(layer_idx); + let logup_count = record.logup_count_at(layer_idx); + cols.num_read_count = F::from_usize(read_count); + cols.num_write_count = F::from_usize(write_count); + cols.num_logup_count = F::from_usize(logup_count); + cols.num_layers = F::from_usize(record.layer_count()); + let lambda = record.lambda_at(layer_idx); + let lambda_prime = record.lambda_prime_at(layer_idx); + let read_lambda_end = ext_pow(lambda, read_count); + let read_lambda_prime_end = ext_pow(lambda_prime, read_count); + let write_lambda_end = read_lambda_end * ext_pow(lambda, write_count); + let write_lambda_prime_end = + read_lambda_prime_end * ext_pow(lambda_prime, write_count); + cols.read_lambda_end = read_lambda_end + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.read_lambda_prime_end = read_lambda_prime_end + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.write_lambda_end = write_lambda_end + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.write_lambda_prime_end = write_lambda_prime_end + .as_basis_coefficients_slice() + .try_into() + .unwrap(); cols.eq_at_r_prime = record .eq_at(layer_idx) .as_basis_coefficients_slice() .try_into() .unwrap(); - cols.r0_claim.copy_from_slice(q0_basis); - cols.w0_claim.copy_from_slice(q0_basis); - cols.q0_claim.copy_from_slice(q0_basis); - if layer_idx == 0 { - cols.read_claim_prime.copy_from_slice(&cols.r0_claim); - cols.write_claim_prime.copy_from_slice(&cols.w0_claim); - cols.logup_claim_prime.copy_from_slice(&cols.q0_claim); + cols.initial_tower_claim = initial_tower_claim; + cols.read_claim_prime = + read_prime.as_basis_coefficients_slice().try_into().unwrap(); + cols.write_claim_prime = write_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.logup_claim_prime = logup_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let read_eval = if layer_idx == 0 { + EF::ZERO } else { - cols.read_claim_prime = - read_prime.as_basis_coefficients_slice().try_into().unwrap(); - cols.write_claim_prime = write_prime - .as_basis_coefficients_slice() - .try_into() - .unwrap(); - cols.logup_claim_prime = logup_prime - .as_basis_coefficients_slice() - .try_into() - .unwrap(); - } + record + .read_prime_claims + .get(layer_idx - 1) + .copied() + .unwrap_or(EF::ZERO) + }; + let write_eval = if layer_idx == 0 { + EF::ZERO + } else { + record + .write_prime_claims + .get(layer_idx - 1) + .copied() + .unwrap_or(EF::ZERO) + }; + let logup_eval = if layer_idx == 0 { + EF::ZERO + } else { + record + .logup_prime_claims + .get(layer_idx - 1) + .copied() + .unwrap_or(EF::ZERO) + }; + cols.read_eval_claim = + read_eval.as_basis_coefficients_slice().try_into().unwrap(); + cols.write_eval_claim = + write_eval.as_basis_coefficients_slice().try_into().unwrap(); + cols.logup_eval_claim = + logup_eval.as_basis_coefficients_slice().try_into().unwrap(); prev_folded_claim = Some(read_claim + write_claim + logup_claim); } diff --git a/ceno_recursion_v2/src/tower/mod.rs b/ceno_recursion_v2/src/tower/mod.rs index 15f48dbb8..0b7128de0 100644 --- a/ceno_recursion_v2/src/tower/mod.rs +++ b/ceno_recursion_v2/src/tower/mod.rs @@ -56,7 +56,9 @@ use openvm_stark_backend::{ AirRef, FiatShamirTranscript, ReadOnlyTranscript, StarkProtocolConfig, TranscriptHistory, p3_maybe_rayon::prelude::*, prover::AirProvingContext, }; -use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, D_EF, EF, F}; +#[cfg(test)] +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, EF, F}; use p3_field::PrimeCharacteristicRing; use p3_matrix::dense::RowMajorMatrix; use recursion_circuit::primitives::exp_bits_len::ExpBitsLenTraceGenerator; @@ -88,12 +90,16 @@ use eyre::Result; // Internal bus definitions mod bus; pub use bus::{ - TowerLogupClaimBus, TowerLogupClaimInputBus, TowerLogupClaimMessage, - TowerLogupLayerChallengeMessage, TowerProdLayerChallengeMessage, TowerProdReadClaimBus, - TowerProdReadClaimInputBus, TowerProdSumClaimMessage, TowerProdWriteClaimBus, - TowerProdWriteClaimInputBus, TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, - TowerSumcheckInputBus, TowerSumcheckInputMessage, TowerSumcheckOutputBus, - TowerSumcheckOutputMessage, + TowerLogupClaimBus, TowerLogupClaimInputBus, TowerLogupClaimMessage, TowerLogupInitBus, + TowerLogupInitMessage, TowerLogupLayerChallengeMessage, TowerLogupRootBus, + TowerLogupRootInputBus, TowerLogupRootInputMessage, TowerLogupRootMessage, + TowerProdInitMessage, TowerProdLayerInputMessage, TowerProdReadClaimBus, + TowerProdReadClaimInputBus, TowerProdRootInputMessage, TowerProdRootMessage, + TowerProdSumClaimMessage, TowerProdWriteClaimBus, TowerProdWriteClaimInputBus, + TowerReadInitBus, TowerReadRootBus, TowerReadRootInputBus, TowerSumcheckChallengeBus, + TowerSumcheckChallengeMessage, TowerSumcheckInputBus, TowerSumcheckInputMessage, + TowerSumcheckOutputBus, TowerSumcheckOutputMessage, TowerWriteInitBus, TowerWriteRootBus, + TowerWriteRootInputBus, }; /// Transcript field-element lengths per tower operation. @@ -107,15 +113,23 @@ pub mod tower_transcript_len { // Label field-element counts: ceil(byte_len / 4). // b"combine subset evals" = 20 bytes → 5 field elements - const LABEL_COMBINE: usize = 5; + pub const LABEL_COMBINE: usize = 5; // b"product_sum" = 11 bytes → 3 field elements - const LABEL_PRODUCT_SUM: usize = 3; + pub const LABEL_PRODUCT_SUM: usize = 3; // b"Internal round" = 14 bytes → 4 field elements - const LABEL_INTERNAL_ROUND: usize = 4; + pub const LABEL_INTERNAL_ROUND: usize = 4; // b"merge" = 5 bytes → 2 field elements - const LABEL_MERGE: usize = 2; + pub const LABEL_MERGE: usize = 2; // usize::to_le_bytes() = 8 bytes → 2 field elements (64-bit platform) - const LABEL_USIZE: usize = 2; + pub const LABEL_USIZE: usize = 2; + + pub const LABEL_COMBINE_VALUES: [usize; LABEL_COMBINE] = + [1651339107, 543518313, 1935832435, 1696625765, 1936482678]; + pub const LABEL_PRODUCT_SUM_VALUES: [usize; LABEL_PRODUCT_SUM] = + [1685025392, 1601463157, 7173491]; + pub const LABEL_INTERNAL_ROUND_VALUES: [usize; LABEL_INTERNAL_ROUND] = + [1702129225, 1818324594, 1970237984, 25710]; + pub const LABEL_MERGE_VALUES: [usize; LABEL_MERGE] = [1735550317, 101]; /// label "combine subset evals" (5) + sample alpha (D_EF) pub const ALPHA_LEN: usize = LABEL_COMBINE + D_EF; @@ -179,7 +193,7 @@ pub mod layer; pub mod sumcheck; #[allow(clippy::module_inception)] mod tower; -pub(crate) use tower::{TowerReplayResult, replay_tower_proof, replay_tower_proof_poseidon}; +pub(crate) use tower::{TowerReplayResult, replay_tower_proof_poseidon}; pub struct TowerModule { // Global bus inventory bus_inventory: BusInventory, @@ -195,6 +209,15 @@ pub struct TowerModule { prod_write_claim_bus: TowerProdWriteClaimBus, logup_claim_input_bus: TowerLogupClaimInputBus, logup_claim_bus: TowerLogupClaimBus, + read_root_input_bus: TowerReadRootInputBus, + read_root_bus: TowerReadRootBus, + read_init_bus: TowerReadInitBus, + write_root_input_bus: TowerWriteRootInputBus, + write_root_bus: TowerWriteRootBus, + write_init_bus: TowerWriteInitBus, + logup_root_input_bus: TowerLogupRootInputBus, + logup_root_bus: TowerLogupRootBus, + logup_init_bus: TowerLogupInitBus, } #[derive(Clone, Debug, Default)] @@ -202,18 +225,17 @@ pub(crate) struct TowerTowerEvalRecord { pub(crate) read_layers: Vec>, pub(crate) write_layers: Vec>, pub(crate) logup_layers: Vec>, + pub(crate) read_active: Vec>, + pub(crate) write_active: Vec>, + pub(crate) logup_active: Vec>, } pub(crate) struct TowerBlobCpu { - input_records: Vec, - /// Per-proof q0 claims matching input_records (one per proof). - proof_q0_claims: Vec, - layer_records: Vec, - tower_records: Vec, - sumcheck_records: Vec, - mus_records: Vec>, - /// Per-chip q0 claims matching layer_records. - q0_claims: Vec, + pub(crate) input_records: Vec, + pub(crate) layer_records: Vec, + pub(crate) tower_records: Vec, + pub(crate) sumcheck_records: Vec, + pub(crate) mus_records: Vec>, } #[derive(Debug, Clone, Default)] @@ -240,6 +262,15 @@ impl TowerModule { prod_write_claim_bus: TowerProdWriteClaimBus::new(b.new_bus_idx()), logup_claim_input_bus: TowerLogupClaimInputBus::new(b.new_bus_idx()), logup_claim_bus: TowerLogupClaimBus::new(b.new_bus_idx()), + read_root_input_bus: TowerReadRootInputBus::new(b.new_bus_idx()), + read_root_bus: TowerReadRootBus::new(b.new_bus_idx()), + read_init_bus: TowerReadInitBus::new(b.new_bus_idx()), + write_root_input_bus: TowerWriteRootInputBus::new(b.new_bus_idx()), + write_root_bus: TowerWriteRootBus::new(b.new_bus_idx()), + write_init_bus: TowerWriteInitBus::new(b.new_bus_idx()), + logup_root_input_bus: TowerLogupRootInputBus::new(b.new_bus_idx()), + logup_root_bus: TowerLogupRootBus::new(b.new_bus_idx()), + logup_init_bus: TowerLogupInitBus::new(b.new_bus_idx()), } } @@ -258,12 +289,15 @@ impl TowerModule { for (&chip_idx, chip_instances) in &proof.chip_proofs { for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { let tidx = ts.len(); - let tower_replay = + let (_, tower_replay) = record_and_replay_tower_preflight(ts, child_vk, chip_idx, chip_proof); preflight.gkr.chips.push(TowerChipTranscriptRange { chip_idx, instance_idx, + num_layers: circuit_vk_for_idx(child_vk, chip_idx) + .map(|circuit_vk| tower_layer_count_from_vk(circuit_vk, chip_proof)) + .unwrap_or(0), tidx, fork_idx: 0, // unused in forked flow tower_replay, @@ -305,9 +339,20 @@ pub(crate) fn interpolate_pair(values: [EF; 2], mu: EF) -> EF { values[0] + delta * mu } -fn accumulate_prod_claims(rows: &[[EF; 2]], lambda: EF, lambda_prime: EF, mu: EF) -> (EF, EF) { - let mut pow_lambda = EF::ONE; - let mut pow_lambda_prime = EF::ONE; +fn ext_pow(base: EF, exp: usize) -> EF { + (0..exp).fold(EF::ONE, |acc, _| acc * base) +} + +fn accumulate_prod_claims( + rows: &[[EF; 2]], + lambda: EF, + lambda_prime: EF, + mu: EF, + lambda_start: EF, + lambda_prime_start: EF, +) -> (EF, EF) { + let mut pow_lambda = lambda_start; + let mut pow_lambda_prime = lambda_prime_start; let mut acc_sum = EF::ZERO; let mut acc_sum_prime = EF::ZERO; @@ -323,11 +368,18 @@ fn accumulate_prod_claims(rows: &[[EF; 2]], lambda: EF, lambda_prime: EF, mu: EF (acc_sum, acc_sum_prime) } -fn accumulate_logup_claims(rows: &[[EF; 4]], lambda: EF, lambda_prime: EF, mu: EF) -> (EF, EF) { - let mut pow_lambda = EF::ONE; - let mut pow_lambda_prime = EF::ONE; +fn accumulate_logup_claims( + rows: &[[EF; 4]], + lambda: EF, + lambda_prime: EF, + mu: EF, + lambda_start: EF, + lambda_prime_start: EF, +) -> (EF, EF) { + let mut pow_lambda = lambda_start; + let mut pow_lambda_prime = lambda_prime_start; let mut acc_sum = EF::ZERO; - let mut acc_q = EF::ZERO; + let mut acc_eval = EF::ZERO; for quad in rows { let p_vals = [quad[0], quad[1]]; @@ -335,13 +387,14 @@ fn accumulate_logup_claims(rows: &[[EF; 4]], lambda: EF, lambda_prime: EF, mu: E let p_xi = interpolate_pair(p_vals, mu); let q_xi = interpolate_pair(q_vals, mu); acc_sum += pow_lambda * (p_xi + lambda * q_xi); + let p_cross = quad[0] * quad[3] + quad[1] * quad[2]; let q_cross = quad[2] * quad[3]; - acc_q += pow_lambda_prime * lambda_prime * q_cross; - pow_lambda *= lambda; - pow_lambda_prime *= lambda_prime; + acc_eval += pow_lambda_prime * (p_cross + lambda_prime * q_cross); + pow_lambda *= lambda * lambda; + pow_lambda_prime *= lambda_prime * lambda_prime; } - (acc_sum, acc_q) + (acc_sum, acc_eval) } pub(crate) fn circuit_vk_for_idx( @@ -353,6 +406,20 @@ pub(crate) fn circuit_vk_for_idx( .and_then(|name| vk.circuit_vks.get(name)) } +pub(crate) fn tower_layer_count_from_vk( + circuit_vk: &VerifyingKey, + chip_proof: &ZKVMChipProof, +) -> usize { + let proof_layer_count = chip_proof.tower_proof.proofs.len(); + let cs = &circuit_vk.cs; + let has_root_specs = cs.num_reads() + cs.num_writes() + cs.num_lks() > 0; + if proof_layer_count == 0 && !has_root_specs { + 0 + } else { + proof_layer_count + 1 + } +} + /// Record all tower transcript events for one chip proof, then replay tower proof. /// Keeping this in the tower module avoids preflight callsites duplicating /// transcript/replay wiring logic. @@ -361,23 +428,49 @@ pub(crate) fn record_and_replay_tower_preflight( child_vk: &RecursionVk, chip_idx: usize, chip_proof: &ZKVMChipProof, -) -> TowerReplayResult +) -> (TowerTranscriptSchedule, TowerReplayResult) where TS: FiatShamirTranscript, { - let _ = record_gkr_transcript(ts, chip_idx, chip_proof); - match circuit_vk_for_idx(child_vk, chip_idx) { - Some(circuit_vk) => match replay_tower_proof(chip_proof, circuit_vk) { + let schedule = record_gkr_transcript(ts, chip_idx, chip_proof); + let replay = match circuit_vk_for_idx(child_vk, chip_idx) { + Some(circuit_vk) => match replay_tower_proof_poseidon(chip_proof, circuit_vk, &schedule) { Ok(replay) => replay, Err(err) => { error!( ?err, - chip_idx, "failed to replay tower proof during preflight" + chip_idx, "failed to replay Poseidon tower proof during preflight" ); TowerReplayResult::default() } }, None => TowerReplayResult::default(), + }; + (schedule, replay) +} + +pub(crate) fn derive_tower_input_claim_for_transcript( + child_vk: &RecursionVk, + chip_idx: usize, + chip_proof: &ZKVMChipProof, + replay: &TowerReplayResult, + schedule: &TowerTranscriptSchedule, +) -> EF { + let Some(circuit_vk) = circuit_vk_for_idx(child_vk, chip_idx) else { + return EF::ZERO; + }; + + match build_chip_records( + 0, 0, chip_idx, 0, true, chip_proof, circuit_vk, replay, schedule, 0, + ) { + Ok((input_record, ..)) => input_record.input_layer_claim, + Err(err) => { + error!( + ?err, + chip_idx, "failed to derive tower input claim during preflight" + ); + EF::ZERO + } } } @@ -385,6 +478,8 @@ where fn build_chip_records( proof_idx: usize, idx: usize, + chip_idx: usize, + fork_idx: usize, is_first_air_idx: bool, chip_proof: &ZKVMChipProof, _circuit_vk: &VerifyingKey, @@ -397,8 +492,11 @@ fn build_chip_records( TowerTowerEvalRecord, TowerSumcheckRecord, Vec, - EF, )> { + let cs = &_circuit_vk.cs; + let read_count = cs.num_reads(); + let write_count = cs.num_writes(); + let logup_count = cs.num_lks(); let spec_layer_count = chip_proof .tower_proof .logup_specs_eval @@ -407,42 +505,101 @@ fn build_chip_records( .chain(chip_proof.tower_proof.prod_specs_eval.iter().map(Vec::len)) .max() .unwrap_or(0); - let layer_count = replay.layers.len().max(spec_layer_count); + let proof_layer_count = chip_proof.tower_proof.proofs.len(); + let layer_count = tower_layer_count_from_vk(_circuit_vk, chip_proof); + let _ = spec_layer_count; + eyre::ensure!( + chip_proof.r_out_evals.len() == read_count, + "read root eval count mismatch at proof {proof_idx} chip {chip_idx}: proof={}, vk={read_count}", + chip_proof.r_out_evals.len() + ); + eyre::ensure!( + chip_proof.w_out_evals.len() == write_count, + "write root eval count mismatch at proof {proof_idx} chip {chip_idx}: proof={}, vk={write_count}", + chip_proof.w_out_evals.len() + ); + eyre::ensure!( + chip_proof.lk_out_evals.len() == logup_count, + "logup root eval count mismatch at proof {proof_idx} chip {chip_idx}: proof={}, vk={logup_count}", + chip_proof.lk_out_evals.len() + ); - let read_count = chip_proof.r_out_evals.len(); - let write_count = chip_proof.w_out_evals.len(); - let logup_count = chip_proof.lk_out_evals.len(); + let mut read_layers = vec![vec![[EF::ZERO; 2]; read_count]; layer_count]; + let mut write_layers = vec![vec![[EF::ZERO; 2]; write_count]; layer_count]; + let mut logup_layers = vec![vec![[EF::ZERO; 4]; logup_count]; layer_count]; + let mut read_active = vec![vec![false; read_count]; layer_count]; + let mut write_active = vec![vec![false; write_count]; layer_count]; + let mut logup_active = vec![vec![false; logup_count]; layer_count]; - let mut read_layers = vec![Vec::with_capacity(read_count); layer_count]; - let mut write_layers = vec![Vec::with_capacity(write_count); layer_count]; - let mut logup_layers = vec![Vec::with_capacity(logup_count); layer_count]; + if layer_count > 0 { + for (spec_idx, evals) in chip_proof.r_out_evals.iter().enumerate() { + if spec_idx < read_count { + let mut pair = [EF::ZERO; 2]; + for (dst, src) in pair.iter_mut().zip(evals.iter().take(2)) { + *dst = *src; + } + read_layers[0][spec_idx] = pair; + read_active[0][spec_idx] = true; + } + } + for (spec_idx, evals) in chip_proof.w_out_evals.iter().enumerate() { + if spec_idx < write_count { + let mut pair = [EF::ZERO; 2]; + for (dst, src) in pair.iter_mut().zip(evals.iter().take(2)) { + *dst = *src; + } + write_layers[0][spec_idx] = pair; + write_active[0][spec_idx] = true; + } + } + for (spec_idx, evals) in chip_proof.lk_out_evals.iter().enumerate() { + if spec_idx < logup_count { + let mut quad = [EF::ZERO; 4]; + for (dst, src) in quad.iter_mut().zip(evals.iter().take(4)) { + *dst = *src; + } + logup_layers[0][spec_idx] = quad; + logup_active[0][spec_idx] = true; + } + } + } for (spec_idx, rounds) in chip_proof.tower_proof.prod_specs_eval.iter().enumerate() { - for layer_idx in 0..layer_count { - let mut pair = [EF::ZERO; 2]; - if let Some(values) = rounds.get(layer_idx) { + for round_idx in 0..proof_layer_count { + if let Some(values) = rounds.get(round_idx) { + let layer_idx = round_idx + 1; + let mut pair = [EF::ZERO; 2]; for (dst, src) in pair.iter_mut().zip(values.iter().take(2)) { *dst = *src; } - } - if spec_idx < read_count { - read_layers[layer_idx].push(pair); - } else { - write_layers[layer_idx].push(pair); + if spec_idx < read_count { + read_layers[layer_idx][spec_idx] = pair; + read_active[layer_idx][spec_idx] = true; + } else { + let write_idx = spec_idx - read_count; + if write_idx < write_count { + write_layers[layer_idx][write_idx] = pair; + write_active[layer_idx][write_idx] = true; + } + } } } } - for rounds in &chip_proof.tower_proof.logup_specs_eval { + for (spec_idx, rounds) in chip_proof.tower_proof.logup_specs_eval.iter().enumerate() { #[allow(clippy::needless_range_loop)] - for layer_idx in 0..layer_count { - let mut quad = [EF::ZERO; 4]; - if let Some(values) = rounds.get(layer_idx) { + for round_idx in 0..proof_layer_count { + if let Some(values) = rounds.get(round_idx) { + let layer_idx = round_idx + 1; + let mut quad = [EF::ZERO; 4]; for (dst, src) in quad.iter_mut().zip(values.iter().take(4)) { *dst = *src; } + if spec_idx < logup_count { + logup_layers[layer_idx][spec_idx] = quad; + logup_active[layer_idx][spec_idx] = true; + } } - logup_layers[layer_idx].push(quad); } } @@ -450,14 +607,18 @@ fn build_chip_records( read_layers, write_layers, logup_layers, + read_active, + write_active, + logup_active, }; let mut layer_record = TowerLayerRecord { proof_idx, idx, + chip_id: chip_idx, is_first_air_idx, - // TowerLayerAir starts after alpha/beta labels+sampling. - tidx: tidx + tower_transcript_len::ALPHA_BETA_LEN, + tidx, + initial_tower_claim: EF::ZERO, layer_claims: Vec::with_capacity(layer_count), lambdas: vec![EF::ZERO; layer_count], eq_at_r_primes: vec![EF::ZERO; layer_count], @@ -494,9 +655,9 @@ fn build_chip_records( // read_len == write_len, // "read/write prod spec count mismatch at layer {layer_idx}: read={read_len}, write={write_len}" // ); - layer_record.read_counts[layer_idx] = read_len.max(1); - layer_record.write_counts[layer_idx] = write_len.max(1); - layer_record.logup_counts[layer_idx] = logup_len.max(1); + layer_record.read_counts[layer_idx] = read_len; + layer_record.write_counts[layer_idx] = write_len; + layer_record.logup_counts[layer_idx] = logup_len; } for layer_idx in 0..layer_count { @@ -505,22 +666,13 @@ fn build_chip_records( .push(convert_logup_claim(chip_proof, layer_idx)); } - let input_layer_claim = layer_record - .layer_claims - .last() - .map(|claim| claim[0]) - .unwrap_or(EF::ZERO); - let mut sumcheck_record = TowerSumcheckRecord { proof_idx, idx, + chip_id: chip_idx, is_first_air_idx, - // First sumcheck transcript row starts at layer_tidx(1) + ALPHA_LEN + SUMCHECK_INIT_LEN. - tidx: tidx - + tower_transcript_len::ALPHA_BETA_LEN - + tower_transcript_len::POST_SUMCHECK_LEN - + tower_transcript_len::ALPHA_LEN - + tower_transcript_len::SUMCHECK_INIT_LEN, + tidx: 0, + layer_tidxs: Vec::new(), evals: Vec::new(), ris: Vec::new(), claims: vec![EF::ZERO; layer_count.saturating_sub(1)], @@ -546,23 +698,50 @@ fn build_chip_records( } } let mut mus_record = vec![EF::ZERO; layer_count]; + if !mus_record.is_empty() { + mus_record[0] = schedule.beta; + } + for layer_idx in 0..layer_count { + layer_record.lambdas[layer_idx] = + schedule.lambdas.get(layer_idx).copied().unwrap_or(EF::ZERO); + if layer_idx > 0 { + mus_record[layer_idx] = schedule.mus.get(layer_idx - 1).copied().unwrap_or(EF::ZERO); + } + } - let q0_claim = chip_proof - .lk_out_evals - .first() - .and_then(|evals| evals.get(2)) - .copied() - .unwrap_or(EF::ZERO); - - let layer_output_lambda = schedule.lambdas.last().copied().unwrap_or(EF::ZERO); - let layer_output_mu = schedule.mus.last().copied().unwrap_or(EF::ZERO); - let input_record = TowerInputRecord { + let layer_output_lambda = if layer_count == 0 { + EF::ZERO + } else { + schedule.lambdas.last().copied().unwrap_or(EF::ZERO) + }; + let layer_output_mu = if layer_count == 0 { + EF::ZERO + } else { + schedule.mus.last().copied().unwrap_or(EF::ZERO) + }; + let mut input_record = TowerInputRecord { proof_idx, idx, + chip_id: chip_idx, tidx, - n_logup: layer_count, + final_tidx: tidx, + num_layers: layer_count, + num_read_specs: read_count, + num_write_specs: write_count, + num_logup_specs: logup_count, + r0_claim: EF::ZERO, + w0_claim: EF::ZERO, + p0_claim: EF::ZERO, + q0_claim: EF::ONE, alpha_logup: schedule.alpha_logup, - input_layer_claim, + r_1: schedule.beta, + read_initial_claim: EF::ZERO, + write_initial_claim: EF::ZERO, + logup_initial_claim: EF::ZERO, + initial_tower_claim: EF::ZERO, + write_lambda_1_start: ext_pow(schedule.alpha_logup, read_count), + logup_lambda_1_start: ext_pow(schedule.alpha_logup, read_count + write_count), + input_layer_claim: EF::ZERO, layer_output_lambda, layer_output_mu, }; @@ -576,20 +755,16 @@ fn build_chip_records( sumcheck_record.evals.len() ); } - for (layer_idx, data) in replay.layers.iter().enumerate() { + for (round_idx, data) in replay.layers.iter().enumerate() { + let layer_idx = round_idx + 1; if layer_idx < layer_record.eq_at_r_primes.len() { layer_record.eq_at_r_primes[layer_idx] = data.eq_at_r; - layer_record.lambdas[layer_idx] = - schedule.lambdas.get(layer_idx).copied().unwrap_or(EF::ZERO); - mus_record[layer_idx] = schedule.mus.get(layer_idx).copied().unwrap_or(EF::ZERO); } - if layer_idx + 1 < layer_count { - if layer_idx < sumcheck_record.claims.len() { - sumcheck_record.claims[layer_idx] = data.claim_in; - } - if layer_idx < layer_record.sumcheck_claims.len() { - layer_record.sumcheck_claims[layer_idx] = data.claim_in; - } + if round_idx < sumcheck_record.claims.len() { + sumcheck_record.claims[round_idx] = data.claim_in; + } + if round_idx < layer_record.sumcheck_claims.len() { + layer_record.sumcheck_claims[round_idx] = data.claim_in; } } @@ -601,23 +776,90 @@ fn build_chip_records( .unwrap_or(EF::ZERO); let lambda_prime = layer_record.lambda_prime_at(layer_idx); let mu = mus_record.get(layer_idx).copied().unwrap_or(EF::ZERO); + let read_count = layer_record.read_count_at(layer_idx); + let write_count = layer_record.write_count_at(layer_idx); + let read_lambda_start = EF::ONE; + let read_lambda_prime_start = EF::ONE; + let write_lambda_start = ext_pow(lambda, read_count); + let write_lambda_prime_start = ext_pow(lambda_prime, read_count); + let logup_lambda_start = ext_pow(lambda, read_count + write_count); + let logup_lambda_prime_start = ext_pow(lambda_prime, read_count + write_count); if let Some(rows) = tower_record.read_layers.get(layer_idx) { - let (claim, prime) = accumulate_prod_claims(rows, lambda, lambda_prime, mu); + let (claim, prime) = accumulate_prod_claims( + rows, + lambda, + lambda_prime, + mu, + read_lambda_start, + read_lambda_prime_start, + ); layer_record.read_claims[layer_idx] = claim; layer_record.read_prime_claims[layer_idx] = prime; + if layer_idx == 0 { + input_record.read_initial_claim = claim; + input_record.r0_claim = rows + .iter() + .zip(tower_record.read_active[layer_idx].iter()) + .filter_map(|(pair, is_active)| is_active.then_some(pair[0] * pair[1])) + .product::(); + } } if let Some(rows) = tower_record.write_layers.get(layer_idx) { - let (claim, prime) = accumulate_prod_claims(rows, lambda, lambda_prime, mu); + let (claim, prime) = accumulate_prod_claims( + rows, + lambda, + lambda_prime, + mu, + write_lambda_start, + write_lambda_prime_start, + ); layer_record.write_claims[layer_idx] = claim; layer_record.write_prime_claims[layer_idx] = prime; + if layer_idx == 0 { + input_record.write_initial_claim = claim; + input_record.w0_claim = rows + .iter() + .zip(tower_record.write_active[layer_idx].iter()) + .filter_map(|(pair, is_active)| is_active.then_some(pair[0] * pair[1])) + .product::(); + } } if let Some(rows) = tower_record.logup_layers.get(layer_idx) { - let (claim, prime) = accumulate_logup_claims(rows, lambda, lambda_prime, mu); + let (claim, prime) = accumulate_logup_claims( + rows, + lambda, + lambda_prime, + mu, + logup_lambda_start, + logup_lambda_prime_start, + ); layer_record.logup_claims[layer_idx] = claim; layer_record.logup_prime_claims[layer_idx] = prime; + if layer_idx == 0 { + input_record.logup_initial_claim = claim; + let mut p0 = EF::ZERO; + let mut q0 = EF::ONE; + for (quad, is_active) in + rows.iter().zip(tower_record.logup_active[layer_idx].iter()) + { + if !*is_active { + continue; + } + let p_cross = quad[0] * quad[3] + quad[1] * quad[2]; + let q_cross = quad[2] * quad[3]; + p0 = p0 * q_cross + p_cross * q0; + q0 *= q_cross; + } + input_record.p0_claim = p0; + input_record.q0_claim = q0; + } } } + input_record.initial_tower_claim = input_record.read_initial_claim + + input_record.write_initial_claim + + input_record.logup_initial_claim; + layer_record.initial_tower_claim = input_record.initial_tower_claim; // Sync sumcheck claims with accumulated values so that the sumcheck trace // uses the same claim_in that TowerLayerAir sends on the sumcheck_input_bus. @@ -627,20 +869,61 @@ fn build_chip_records( let folded = layer_record.read_claims[k] + layer_record.write_claims[k] + layer_record.logup_claims[k]; + if let Some(replay_layer) = replay.layers.get(k) { + eyre::ensure!( + folded == replay_layer.claim_in, + "tower folded claim mismatch at proof {proof_idx} chip {idx} layer {k}: folded={folded:?}, replay={:?}", + replay_layer.claim_in + ); + } sumcheck_record.claims[k] = folded; layer_record.sumcheck_claims[k] = folded; } + if let Some(last_layer_idx) = layer_count.checked_sub(1) { + input_record.input_layer_claim = layer_record.read_claims[last_layer_idx] + + layer_record.write_claims[last_layer_idx] + + layer_record.logup_claims[last_layer_idx]; + input_record.layer_output_lambda = layer_record.lambdas[last_layer_idx]; + input_record.layer_output_mu = mus_record[last_layer_idx]; + input_record.final_tidx = + layer_record.layer_tidx(last_layer_idx) + layer_record.layer_span(last_layer_idx); + } + // Compute eq_at_r_primes from ris and mus so that TowerLayerAir's eq values // match the sumcheck trace's eq_out on the sumcheck_output_bus. // Sumcheck internal layer k (0-indexed) → TowerLayerAir layer k+1. let num_sumcheck_layers = layer_count.saturating_sub(1); + sumcheck_record.layer_tidxs = (0..num_sumcheck_layers) + .map(|k| layer_record.layer_tidx(k + 1) + tower_transcript_len::SUMCHECK_INIT_LEN) + .collect(); + if let Some(&first_tidx) = sumcheck_record.layer_tidxs.first() { + sumcheck_record.tidx = first_tidx; + } for k in 0..num_sumcheck_layers { let eq = TowerSumcheckRecord::compute_eq_for_layer(k, &mus_record, &sumcheck_record.ris); if k + 1 < layer_record.eq_at_r_primes.len() { layer_record.eq_at_r_primes[k + 1] = eq; } } + for (round_idx, replay_layer) in replay.layers.iter().enumerate() { + let layer_idx = round_idx + 1; + if layer_idx < layer_record.layer_count() { + let expected = layer_record.eq_at_r_primes[layer_idx] + * (layer_record.read_prime_claims[layer_idx] + + layer_record.write_prime_claims[layer_idx] + + layer_record.logup_prime_claims[layer_idx]); + eyre::ensure!( + expected == replay_layer.claim_out, + "tower expected-eval mismatch at proof {proof_idx} idx {idx} chip_idx {chip_idx} fork_idx {fork_idx} layer {layer_idx}: expected={expected:?}, replay={:?}, eq={:?}, read_prime={:?}, write_prime={:?}, logup_prime={:?}", + replay_layer.claim_out, + layer_record.eq_at_r_primes[layer_idx], + layer_record.read_prime_claims[layer_idx], + layer_record.write_prime_claims[layer_idx], + layer_record.logup_prime_claims[layer_idx], + ); + } + } Ok(( input_record, @@ -648,7 +931,6 @@ fn build_chip_records( tower_record, sumcheck_record, mus_record, - q0_claim, )) } @@ -660,10 +942,20 @@ impl AirModule for TowerModule { fn airs>(&self) -> Vec> { let gkr_input_air = TowerInputAir { tower_module_bus: self.bus_inventory.tower_module_bus, + tower_root_claim_bus: self.bus_inventory.tower_root_claim_bus, main_bus: self.bus_inventory.main_bus, transcript_bus: self.bus_inventory.transcript_bus, layer_input_bus: self.layer_input_bus, layer_output_bus: self.layer_output_bus, + read_root_input_bus: self.read_root_input_bus, + read_root_bus: self.read_root_bus, + read_init_bus: self.read_init_bus, + write_root_input_bus: self.write_root_input_bus, + write_root_bus: self.write_root_bus, + write_init_bus: self.write_init_bus, + logup_root_input_bus: self.logup_root_input_bus, + logup_root_bus: self.logup_root_bus, + logup_init_bus: self.logup_init_bus, }; let gkr_layer_air = TowerLayerAir { @@ -686,18 +978,27 @@ impl AirModule for TowerModule { transcript_bus: self.bus_inventory.transcript_bus, prod_claim_input_bus: self.prod_read_claim_input_bus, prod_claim_bus: self.prod_read_claim_bus, + root_input_bus: self.read_root_input_bus, + root_bus: self.read_root_bus, + init_bus: self.read_init_bus, }; let gkr_prod_write_sum_air = TowerProdWriteSumCheckClaimAir { transcript_bus: self.bus_inventory.transcript_bus, prod_claim_input_bus: self.prod_write_claim_input_bus, prod_claim_bus: self.prod_write_claim_bus, + root_input_bus: self.write_root_input_bus, + root_bus: self.write_root_bus, + init_bus: self.write_init_bus, }; let gkr_logup_sum_air = TowerLogupSumCheckClaimAir { transcript_bus: self.bus_inventory.transcript_bus, logup_claim_input_bus: self.logup_claim_input_bus, logup_claim_bus: self.logup_claim_bus, + root_input_bus: self.logup_root_input_bus, + root_bus: self.logup_root_bus, + init_bus: self.logup_init_bus, }; let gkr_sumcheck_air = TowerLayerSumcheckAir::new( @@ -739,12 +1040,10 @@ pub(crate) fn build_gkr_blob( preflights: &[Preflight], ) -> Result { let mut input_records = Vec::new(); - let mut proof_q0_claims = Vec::new(); let mut layer_records = Vec::new(); let mut tower_records = Vec::new(); let mut sumcheck_records = Vec::new(); let mut mus_records = Vec::new(); - let mut q0_claims = Vec::new(); eyre::ensure!( proofs.len() == preflights.len(), @@ -753,11 +1052,6 @@ pub(crate) fn build_gkr_blob( for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights).enumerate() { let mut has_chip = false; - let mut first_chip_alpha = EF::ZERO; - let mut first_chip_q0 = EF::ZERO; - let mut last_input_layer_claim = EF::ZERO; - let mut last_layer_output_lambda = EF::ZERO; - let mut last_layer_output_mu = EF::ZERO; let sorted_idx_by_chip: std::collections::BTreeMap = preflight .proof_shape @@ -809,58 +1103,31 @@ pub(crate) fn build_gkr_blob( let idx = entry_idx; // Compute global tidx from fork-local tidx for trace column values. let global_tidx = preflight.fork_global_offset(pf_entry.fork_idx) + pf_entry.tidx; - let ( - chip_input_record, - layer_record, - tower_record, - sumcheck_record, - mus_record, - q0_claim, - ) = build_chip_records( - proof_idx, - idx, - entry_idx == 0, - chip_proof, - circuit_vk, - &poseidon_replay, - &schedule, - global_tidx, - )?; - - // Capture first chip's alpha and q0 for the proof-level record - if entry_idx == 0 { - first_chip_alpha = chip_input_record.alpha_logup; - first_chip_q0 = q0_claim; - } - // Always update to latest chip for combined values - last_input_layer_claim = chip_input_record.input_layer_claim; - last_layer_output_lambda = chip_input_record.layer_output_lambda; - last_layer_output_mu = chip_input_record.layer_output_mu; - - // Per-chip records (not input_records) + let (chip_input_record, layer_record, tower_record, sumcheck_record, mus_record) = + build_chip_records( + proof_idx, + idx, + chip_idx, + pf_entry.fork_idx, + entry_idx == 0, + chip_proof, + circuit_vk, + &poseidon_replay, + &schedule, + global_tidx, + )?; + + input_records.push(chip_input_record); layer_records.push(layer_record); tower_records.push(tower_record); sumcheck_records.push(sumcheck_record); mus_records.push(mus_record); - q0_claims.push(q0_claim); } - // ONE input record per proof (matching ProofIdxSubAir constraint) - input_records.push(TowerInputRecord { - proof_idx, - idx: 0, - tidx: preflight.proof_shape.post_tidx, - n_logup: preflight.proof_shape.n_logup, - alpha_logup: first_chip_alpha, - input_layer_claim: last_input_layer_claim, - layer_output_lambda: last_layer_output_lambda, - layer_output_mu: last_layer_output_mu, - }); - proof_q0_claims.push(first_chip_q0); - if !has_chip { layer_records.push(TowerLayerRecord { idx: 0, + chip_id: 0, proof_idx, is_first_air_idx: true, ..Default::default() @@ -869,32 +1136,20 @@ pub(crate) fn build_gkr_blob( sumcheck_records.push(TowerSumcheckRecord { proof_idx, idx: 0, + chip_id: 0, is_first_air_idx: true, ..Default::default() }); mus_records.push(vec![]); - q0_claims.push(EF::ZERO); } } - if input_records.is_empty() { - input_records.push(TowerInputRecord::default()); - proof_q0_claims.push(EF::ZERO); - layer_records.push(TowerLayerRecord::default()); - sumcheck_records.push(TowerSumcheckRecord::default()); - tower_records.push(TowerTowerEvalRecord::default()); - mus_records.push(vec![]); - q0_claims.push(EF::ZERO); - } - Ok(TowerBlobCpu { input_records, - proof_q0_claims, layer_records, tower_records, sumcheck_records, mus_records, - q0_claims, }) } @@ -934,53 +1189,47 @@ where // Reconstruct the transcript events consumed by tower-related AIRs. // This keeps preflight transcript history aligned with TowerLayer/Sumcheck/ // ProdClaim/LogupClaim transcript bus interactions. - let read_count = chip_proof.r_out_evals.len(); - let layer_count = chip_proof - .tower_proof - .logup_specs_eval - .iter() - .map(Vec::len) - .chain(chip_proof.tower_proof.prod_specs_eval.iter().map(Vec::len)) - .max() - .unwrap_or(0); + let round_count = chip_proof.tower_proof.proofs.len(); let log2_num_fanin: usize = 1; // ceil_log2(NUM_FANIN=2) = 1 - let mut lambdas = Vec::with_capacity(layer_count); - let mut mus = Vec::with_capacity(layer_count); + let mut lambdas = Vec::with_capacity(round_count + 1); + lambdas.push(alpha_logup); + let mut mus = Vec::with_capacity(round_count); let mut ris = Vec::new(); - for layer_idx in 0..layer_count { - // For layer 0, there is no transcript lambda sample — the native verifier - // goes straight from beta to sumcheck. Use alpha_logup as the weighting - // challenge for the root layer (matching native's initial alpha_pows). - // For layers > 0, this sample corresponds to get_challenge_pows in the - // native verifier (the "next alpha" after the previous round's merge). - let lambda = if layer_idx > 0 { - transcript_observe_label(ts, b"combine subset evals"); - FiatShamirTranscript::::sample_ext(ts) - } else { - alpha_logup - }; - lambdas.push(lambda); - - if let Some(round_msgs) = chip_proof.tower_proof.proofs.get(layer_idx) { - // Mirror native sumcheck IOPVerifierState::verify init: - // append_message(max_num_variables.to_leBytes()) - // append_message(max_degree.to_leBytes()) - let max_num_variables = (layer_idx + 1) * log2_num_fanin; - let max_degree: usize = 3; // NUM_FANIN + 1 - transcript_observe_label(ts, &max_num_variables.to_le_bytes()); - transcript_observe_label(ts, &max_degree.to_le_bytes()); - - for (_ri_idx, msg) in round_msgs.iter().enumerate() { - for eval in &msg.evaluations { + for round_idx in 0..round_count { + let round_msgs = &chip_proof.tower_proof.proofs[round_idx]; + // Mirror native sumcheck IOPVerifierState::verify init: + // append_message(max_num_variables.to_leBytes()) + // append_message(max_degree.to_leBytes()) + let max_num_variables = (round_idx + 1) * log2_num_fanin; + let max_degree: usize = 3; // NUM_FANIN + 1 + transcript_observe_label(ts, &max_num_variables.to_le_bytes()); + transcript_observe_label(ts, &max_degree.to_le_bytes()); + + for msg in round_msgs { + for eval in &msg.evaluations { + ts.observe_ext(*eval); + } + // Mirror native: sample_and_append_challenge(b"Internal round") + transcript_observe_label(ts, b"Internal round"); + let ri = FiatShamirTranscript::::sample_ext(ts); + ris.push(ri); + } + + for rounds in &chip_proof.tower_proof.prod_specs_eval { + if let Some(evals) = rounds.get(round_idx) { + for eval in evals { + ts.observe_ext(*eval); + } + } + } + for rounds in &chip_proof.tower_proof.logup_specs_eval { + if let Some(evals) = rounds.get(round_idx) { + for eval in evals { ts.observe_ext(*eval); } - // Mirror native: sample_and_append_challenge(b"Internal round") - transcript_observe_label(ts, b"Internal round"); - let ri = FiatShamirTranscript::::sample_ext(ts); - ris.push(ri); } } @@ -988,9 +1237,12 @@ where transcript_observe_label(ts, b"merge"); let mu = FiatShamirTranscript::::sample_ext(ts); mus.push(mu); + + transcript_observe_label(ts, b"combine subset evals"); + let next_lambda = FiatShamirTranscript::::sample_ext(ts); + lambdas.push(next_lambda); } - let _ = read_count; TowerTranscriptSchedule { alpha_logup, beta, @@ -1044,6 +1296,510 @@ impl> TraceGenModule } } +#[cfg(test)] +mod debug_tests { + use super::*; + use crate::{ + system::RecursionPcs, + utils::{TranscriptLabel, transcript_observe_label}, + }; + use ceno_zkvm::scheme::{constants::NUM_FANIN, verifier::TowerVerify}; + use mpcs::PolynomialCommitmentScheme; + use multilinear_extensions::util::ceil_log2; + use openvm_stark_sdk::config::baby_bear_poseidon2::default_duplex_sponge_recorder; + use p3_field::BasedVectorSpace; + use transcript::{Transcript, basic::BasicTranscript}; + use witness::next_pow2_instance_padding; + + fn limbs(value: RecursionField) -> [F; D_EF] { + value.as_basis_coefficients_slice().try_into().unwrap() + } + + fn fixture_path(file_name: &str) -> Option { + std::env::var_os("CENO_RECURSION_V2_FIXTURE_DIR") + .map(std::path::PathBuf::from) + .into_iter() + .chain([std::path::PathBuf::from("./src/imported")]) + .map(|dir| dir.join(file_name)) + .find(|path| path.exists()) + } + + fn load_fixture() -> Option<(RecursionProof, RecursionVk)> { + let proof_path = fixture_path("proof.bin")?; + let vk_path = fixture_path("vk.bin")?; + let proof_bytes = std::fs::read(proof_path).ok()?; + let proof = bincode::deserialize::>(&proof_bytes) + .ok() + .and_then(|proofs| proofs.into_iter().next()) + .or_else(|| bincode::deserialize::(&proof_bytes).ok())?; + let mut vk = bincode::deserialize::(&std::fs::read(vk_path).ok()?).ok()?; + vk.rebuild_circuit_index(); + Some((proof, vk)) + } + + fn observe_basic_prefix( + ts: &mut BasicTranscript, + vk: &RecursionVk, + proof: &RecursionProof, + ) { + ts.append_field_element_exts(&vk.compute_digest()); + for (_, circuit_vk) in vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance.iter() { + ts.append_field_element( + &proof + .public_values + .query_by_index::(instance_value.0), + ); + } + } + if let Some(commitment) = vk.fixed_commit.as_ref() { + RecursionPcs::write_commitment(commitment, ts).unwrap(); + } + if let Some(commitment) = vk.fixed_no_omc_init_commit.as_ref() { + RecursionPcs::write_commitment(commitment, ts).unwrap(); + } + RecursionPcs::write_commitment(&proof.witin_commit, ts).unwrap(); + } + + fn observe_basic_tower( + ts: &mut BasicTranscript, + chip_proof: &ZKVMChipProof, + ) -> TowerTranscriptSchedule { + for eval in chip_proof + .r_out_evals + .iter() + .chain(chip_proof.w_out_evals.iter()) + .chain(chip_proof.lk_out_evals.iter()) + .flatten() + { + ts.append_field_element_ext(eval); + } + let alpha_logup = ::sumcheck::util::get_challenge_pows::( + chip_proof.r_out_evals.len() + + chip_proof.w_out_evals.len() + + 2 * chip_proof.lk_out_evals.len(), + ts, + ) + .get(1) + .copied() + .unwrap_or(RecursionField::ONE); + let beta = ts.sample_and_append_vec(b"product_sum", 1)[0]; + let mut lambdas = vec![alpha_logup]; + let mut mus = Vec::new(); + let mut ris = Vec::new(); + for (round_idx, round_msgs) in chip_proof.tower_proof.proofs.iter().enumerate() { + ts.append_message(&(round_idx + 1).to_le_bytes()); + ts.append_message(&3usize.to_le_bytes()); + for msg in round_msgs { + for eval in &msg.evaluations { + ts.append_field_element_ext(eval); + } + ris.push(ts.sample_and_append_challenge(b"Internal round").elements); + } + for rounds in &chip_proof.tower_proof.prod_specs_eval { + if let Some(evals) = rounds.get(round_idx) { + ts.append_field_element_exts(evals); + } + } + for rounds in &chip_proof.tower_proof.logup_specs_eval { + if let Some(evals) = rounds.get(round_idx) { + ts.append_field_element_exts(evals); + } + } + mus.push(ts.sample_and_append_vec(b"merge", 1)[0]); + let next_lambda = ::sumcheck::util::get_challenge_pows::( + chip_proof.r_out_evals.len() + + chip_proof.w_out_evals.len() + + 2 * chip_proof.lk_out_evals.len(), + ts, + ) + .get(1) + .copied() + .unwrap_or(RecursionField::ONE); + lambdas.push(next_lambda); + } + TowerTranscriptSchedule { + alpha_logup, + beta, + lambdas, + mus, + ris, + } + } + + fn manual_first_expected(chip_proof: &ZKVMChipProof, lambda: EF, eq: EF) -> EF { + let prod_count = chip_proof.r_out_evals.len() + chip_proof.w_out_evals.len(); + let mut total = EF::ZERO; + let mut pow = EF::ONE; + for rounds in &chip_proof.tower_proof.prod_specs_eval { + if let Some(evals) = rounds.first() { + total += pow * evals.iter().copied().product::(); + } + pow *= lambda; + } + debug_assert_eq!(prod_count, chip_proof.tower_proof.prod_specs_eval.len()); + for rounds in &chip_proof.tower_proof.logup_specs_eval { + if let Some(evals) = rounds.first() { + let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); + total += pow * (p1 * q2 + p2 * q1); + pow *= lambda; + total += pow * (q1 * q2); + pow *= lambda; + } + } + eq * total + } + + #[test] + #[ignore] + fn debug_chip_15_tower() { + let Some((proof, vk)) = load_fixture() else { + return; + }; + + let target_fork = 10usize; + let (chip_idx, chip_proof) = proof + .chip_proofs + .iter() + .flat_map(|(chip_idx, proofs)| { + proofs.iter().map(move |chip_proof| (*chip_idx, chip_proof)) + }) + .nth(target_fork) + .expect("target fork should exist"); + assert_eq!(chip_idx, 15); + let circuit_vk = circuit_vk_for_idx(&vk, chip_idx).unwrap(); + + let mut basic = BasicTranscript::::new(b"riscv"); + observe_basic_prefix(&mut basic, &vk, &proof); + let basic_alpha = basic.read_challenge().elements; + let basic_beta = basic.read_challenge().elements; + + let mut basic_fork = BasicTranscript::::new(b"fork"); + basic_fork.append_field_element_ext(&basic_alpha); + basic_fork.append_field_element_ext(&basic_beta); + basic_fork.append_field_element(&F::from_usize(target_fork)); + basic_fork.append_field_element(&F::from_usize(chip_idx)); + for num_instance in &chip_proof.num_instances { + basic_fork.append_field_element(&F::from_usize(*num_instance)); + } + let basic_schedule = observe_basic_tower(&mut basic_fork, chip_proof); + + let num_instances: usize = chip_proof.num_instances.iter().sum(); + let mut num_vars = ceil_log2(next_pow2_instance_padding(num_instances)); + if circuit_vk.get_cs().has_ecc_ops() { + num_vars += 1; + } + num_vars += circuit_vk.get_cs().rotation_vars().unwrap_or(0); + let num_batched = chip_proof.r_out_evals.len() + + chip_proof.w_out_evals.len() + + chip_proof.lk_out_evals.len(); + + let eq0 = basic_schedule.beta * basic_schedule.ris[0] + + (EF::ONE - basic_schedule.beta) * (EF::ONE - basic_schedule.ris[0]); + eprintln!( + "chip_idx={chip_idx} fork={target_fork} num_vars={num_vars} num_batched={num_batched} r={} w={} lk={} proofs={} prod_specs={} logup_specs={} lambda0={:?} beta0={:?} ri0={:?} manual_expected={:?}", + chip_proof.r_out_evals.len(), + chip_proof.w_out_evals.len(), + chip_proof.lk_out_evals.len(), + chip_proof.tower_proof.proofs.len(), + chip_proof.tower_proof.prod_specs_eval.len(), + chip_proof.tower_proof.logup_specs_eval.len(), + limbs(basic_schedule.alpha_logup), + limbs(basic_schedule.beta), + limbs(basic_schedule.ris[0]), + limbs(manual_first_expected( + chip_proof, + basic_schedule.alpha_logup, + eq0 + )), + ); + + let replay = replay_tower_proof_poseidon(chip_proof, circuit_vk, &basic_schedule).unwrap(); + if let Some(layer0) = replay.layers.first() { + eprintln!( + "poseidon replay claim_in={:?} claim_out={:?} eq={:?}", + limbs(layer0.claim_in), + limbs(layer0.claim_out), + limbs(layer0.eq_at_r), + ); + } + + let mut basic_verify = BasicTranscript::::new(b"fork"); + basic_verify.append_field_element_ext(&basic_alpha); + basic_verify.append_field_element_ext(&basic_beta); + basic_verify.append_field_element(&F::from_usize(target_fork)); + basic_verify.append_field_element(&F::from_usize(chip_idx)); + for num_instance in &chip_proof.num_instances { + basic_verify.append_field_element(&F::from_usize(*num_instance)); + } + for eval in chip_proof + .r_out_evals + .iter() + .chain(chip_proof.w_out_evals.iter()) + .chain(chip_proof.lk_out_evals.iter()) + .flatten() + { + basic_verify.append_field_element_ext(eval); + } + let tower_verify_result = TowerVerify::verify( + chip_proof + .r_out_evals + .iter() + .cloned() + .chain(chip_proof.w_out_evals.iter().cloned()) + .collect(), + chip_proof.lk_out_evals.clone(), + &chip_proof.tower_proof, + vec![num_vars; num_batched], + NUM_FANIN, + &mut basic_verify, + ); + eprintln!( + "native TowerVerify result={:?}", + tower_verify_result.as_ref().map(|_| ()) + ); + } + + #[test] + #[ignore] + fn debug_compare_all_tower_schedules() { + let Some((proof, vk)) = load_fixture() else { + return; + }; + + let mut basic = BasicTranscript::::new(b"riscv"); + observe_basic_prefix(&mut basic, &vk, &proof); + let basic_alpha = basic.read_challenge().elements; + let basic_beta = basic.read_challenge().elements; + + let mut sponge = default_duplex_sponge_recorder(); + transcript_observe_label(&mut sponge, TranscriptLabel::Riscv.as_bytes()); + let mut openvm_preflight = Preflight::default(); + super::super::circuit::inner::vm_pvs::run_preflight( + &vk, + &proof, + &mut openvm_preflight, + &mut sponge, + ); + let openvm_alpha = openvm_preflight.vm_pvs.lookup_challenge_alpha; + let openvm_beta = openvm_preflight.vm_pvs.lookup_challenge_beta; + + eprintln!( + "global basic alpha={:?} beta={:?}; openvm alpha={:?} beta={:?}", + limbs(basic_alpha), + limbs(basic_beta), + limbs(openvm_alpha), + limbs(openvm_beta), + ); + + let mut checked = 0usize; + let mut mismatches = 0usize; + for (fork_id, (&chip_idx, chip_proof)) in proof + .chip_proofs + .iter() + .flat_map(|(chip_idx, proofs)| { + proofs.iter().map(move |chip_proof| (chip_idx, chip_proof)) + }) + .enumerate() + { + let mut basic_fork = BasicTranscript::::new(b"fork"); + basic_fork.append_field_element_ext(&basic_alpha); + basic_fork.append_field_element_ext(&basic_beta); + basic_fork.append_field_element(&F::from_usize(fork_id)); + basic_fork.append_field_element(&F::from_usize(chip_idx)); + for num_instance in &chip_proof.num_instances { + basic_fork.append_field_element(&F::from_usize(*num_instance)); + } + let basic_schedule = observe_basic_tower(&mut basic_fork, chip_proof); + + let mut openvm_fork = default_duplex_sponge_recorder(); + transcript_observe_label(&mut openvm_fork, TranscriptLabel::Fork.as_bytes()); + FiatShamirTranscript::::observe_ext( + &mut openvm_fork, + openvm_alpha, + ); + FiatShamirTranscript::::observe_ext( + &mut openvm_fork, + openvm_beta, + ); + FiatShamirTranscript::::observe( + &mut openvm_fork, + F::from_usize(fork_id), + ); + FiatShamirTranscript::::observe( + &mut openvm_fork, + F::from_usize(chip_idx), + ); + for num_instance in &chip_proof.num_instances { + FiatShamirTranscript::::observe( + &mut openvm_fork, + F::from_usize(*num_instance), + ); + } + let openvm_schedule = record_gkr_transcript(&mut openvm_fork, chip_idx, chip_proof); + + let same = basic_schedule.alpha_logup == openvm_schedule.alpha_logup + && basic_schedule.beta == openvm_schedule.beta + && basic_schedule.lambdas == openvm_schedule.lambdas + && basic_schedule.mus == openvm_schedule.mus + && basic_schedule.ris == openvm_schedule.ris; + if !same { + mismatches += 1; + eprintln!( + "schedule mismatch fork={fork_id} chip_idx={chip_idx} basic lambda0={:?} beta0={:?} ri0={:?}; openvm lambda0={:?} beta0={:?} ri0={:?}", + limbs(basic_schedule.alpha_logup), + limbs(basic_schedule.beta), + basic_schedule.ris.first().copied().map(limbs), + limbs(openvm_schedule.alpha_logup), + limbs(openvm_schedule.beta), + openvm_schedule.ris.first().copied().map(limbs), + ); + } + checked += 1; + } + + eprintln!("checked {checked} fork schedules, mismatches={mismatches}"); + assert_eq!(mismatches, 0); + } + + #[test] + #[ignore] + fn debug_compare_tower_transcripts() { + let Some((proof, vk)) = load_fixture() else { + return; + }; + let (&chip_idx, chip_instances) = proof.chip_proofs.iter().next().unwrap(); + let chip_proof = &chip_instances[0]; + let circuit_vk = circuit_vk_for_idx(&vk, chip_idx).unwrap(); + + let mut basic = BasicTranscript::::new(b"riscv"); + observe_basic_prefix(&mut basic, &vk, &proof); + let basic_alpha = basic.read_challenge().elements; + let basic_beta = basic.read_challenge().elements; + + let mut basic_fork = BasicTranscript::::new(b"fork"); + basic_fork.append_field_element_ext(&basic_alpha); + basic_fork.append_field_element_ext(&basic_beta); + basic_fork.append_field_element(&F::from_usize(0)); + basic_fork.append_field_element(&F::from_usize(chip_idx)); + for num_instance in &chip_proof.num_instances { + basic_fork.append_field_element(&F::from_usize(*num_instance)); + } + let basic_schedule = observe_basic_tower(&mut basic_fork, chip_proof); + + let mut sponge = default_duplex_sponge_recorder(); + transcript_observe_label(&mut sponge, TranscriptLabel::Riscv.as_bytes()); + let mut openvm_preflight = Preflight::default(); + super::super::circuit::inner::vm_pvs::run_preflight( + &vk, + &proof, + &mut openvm_preflight, + &mut sponge, + ); + let openvm_alpha = openvm_preflight.vm_pvs.lookup_challenge_alpha; + let openvm_beta = openvm_preflight.vm_pvs.lookup_challenge_beta; + let mut openvm_fork = default_duplex_sponge_recorder(); + transcript_observe_label(&mut openvm_fork, TranscriptLabel::Fork.as_bytes()); + FiatShamirTranscript::::observe_ext( + &mut openvm_fork, + openvm_alpha, + ); + FiatShamirTranscript::::observe_ext(&mut openvm_fork, openvm_beta); + FiatShamirTranscript::::observe( + &mut openvm_fork, + F::from_usize(0), + ); + FiatShamirTranscript::::observe( + &mut openvm_fork, + F::from_usize(chip_idx), + ); + for num_instance in &chip_proof.num_instances { + FiatShamirTranscript::::observe( + &mut openvm_fork, + F::from_usize(*num_instance), + ); + } + let openvm_schedule = record_gkr_transcript(&mut openvm_fork, chip_idx, chip_proof); + + eprintln!( + "basic alpha={:?} beta={:?} tower_lambda0={:?} beta0={:?} ri0={:?}", + limbs(basic_alpha), + limbs(basic_beta), + limbs(basic_schedule.alpha_logup), + limbs(basic_schedule.beta), + limbs(basic_schedule.ris[0]), + ); + eprintln!( + "openvm alpha={:?} beta={:?} tower_lambda0={:?} beta0={:?} ri0={:?}", + limbs(openvm_alpha), + limbs(openvm_beta), + limbs(openvm_schedule.alpha_logup), + limbs(openvm_schedule.beta), + limbs(openvm_schedule.ris[0]), + ); + let eq0 = basic_schedule.beta * basic_schedule.ris[0] + + (EF::ONE - basic_schedule.beta) * (EF::ONE - basic_schedule.ris[0]); + eprintln!( + "manual expected cur={:?} next={:?} eq={:?}", + limbs(manual_first_expected( + chip_proof, + basic_schedule.lambdas[0], + eq0 + )), + limbs(manual_first_expected( + chip_proof, + basic_schedule.lambdas[1], + eq0 + )), + limbs(eq0), + ); + + let mut basic_verify = BasicTranscript::::new(b"fork"); + basic_verify.append_field_element_ext(&basic_alpha); + basic_verify.append_field_element_ext(&basic_beta); + basic_verify.append_field_element(&F::from_usize(0)); + basic_verify.append_field_element(&F::from_usize(chip_idx)); + for num_instance in &chip_proof.num_instances { + basic_verify.append_field_element(&F::from_usize(*num_instance)); + } + for eval in chip_proof + .r_out_evals + .iter() + .chain(chip_proof.w_out_evals.iter()) + .chain(chip_proof.lk_out_evals.iter()) + .flatten() + { + basic_verify.append_field_element_ext(eval); + } + let num_instances: usize = chip_proof.num_instances.iter().sum(); + let mut num_vars = ceil_log2(next_pow2_instance_padding(num_instances)); + if circuit_vk.get_cs().has_ecc_ops() { + num_vars += 1; + } + num_vars += circuit_vk.get_cs().rotation_vars().unwrap_or(0); + let num_batched = chip_proof.r_out_evals.len() + + chip_proof.w_out_evals.len() + + chip_proof.lk_out_evals.len(); + let tower_verify_result = TowerVerify::verify( + chip_proof + .r_out_evals + .iter() + .cloned() + .chain(chip_proof.w_out_evals.iter().cloned()) + .collect(), + chip_proof.lk_out_evals.clone(), + &chip_proof.tower_proof, + vec![num_vars; num_batched], + NUM_FANIN, + &mut basic_verify, + ); + eprintln!( + "native TowerVerify result={:?}", + tower_verify_result.as_ref().map(|_| ()) + ); + } +} + // To reduce the number of structs and trait implementations, we collect them into a single enum // with enum dispatch. #[derive(strum_macros::Display, strum::EnumDiscriminants)] @@ -1080,14 +1836,10 @@ impl RowMajorChip for TowerModuleChip { ) -> Option> { use TowerModuleChip::*; match self { - Input => TowerInputTraceGenerator.generate_trace( - &(&blob.input_records, &blob.proof_q0_claims), - required_height, - ), - Layer => TowerLayerTraceGenerator.generate_trace( - &(&blob.layer_records, &blob.mus_records, &blob.q0_claims), - required_height, - ), + Input => TowerInputTraceGenerator + .generate_trace(&blob.input_records.as_slice(), required_height), + Layer => TowerLayerTraceGenerator + .generate_trace(&(&blob.layer_records, &blob.mus_records), required_height), ProdReadClaim => TowerProdReadSumCheckClaimTraceGenerator.generate_trace( &(&blob.layer_records, &blob.tower_records, &blob.mus_records), required_height, diff --git a/ceno_recursion_v2/src/tower/sumcheck/air.rs b/ceno_recursion_v2/src/tower/sumcheck/air.rs index 39c0bdec6..62378bbf2 100644 --- a/ceno_recursion_v2/src/tower/sumcheck/air.rs +++ b/ceno_recursion_v2/src/tower/sumcheck/air.rs @@ -15,7 +15,7 @@ use crate::tower::bus::{ TowerSumcheckInputMessage, TowerSumcheckOutputBus, TowerSumcheckOutputMessage, }; use recursion_circuit::{ - bus::{TranscriptBus, XiRandomnessBus, XiRandomnessMessage}, + bus::{TranscriptBus, XiRandomnessBus}, utils::{ assert_one_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar, ext_field_one_minus, ext_field_subtract, @@ -29,6 +29,7 @@ pub struct TowerLayerSumcheckCols { pub is_enabled: T, pub proof_idx: T, pub idx: T, + pub chip_id: T, pub layer_idx: T, pub is_first_idx: T, pub is_first_layer: T, @@ -289,7 +290,9 @@ where ); // Transcript index increment - use crate::tower::tower_transcript_len::ROUND_LEN; + use crate::tower::tower_transcript_len::{ + LABEL_INTERNAL_ROUND, LABEL_INTERNAL_ROUND_VALUES, ROUND_LEN, + }; builder.when(is_transition_round.clone()).assert_eq( next.tidx, local.tidx.into() + AB::Expr::from_usize(ROUND_LEN), @@ -307,7 +310,7 @@ where builder, local.proof_idx, TowerSumcheckInputMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), is_last_layer: local.is_last_layer.into(), tidx: local.tidx.into(), @@ -321,7 +324,7 @@ where builder, local.proof_idx, TowerSumcheckOutputMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), tidx: local.tidx.into() + AB::Expr::from_usize(ROUND_LEN), claim_out: local.claim_out.map(Into::into), @@ -336,7 +339,7 @@ where builder, local.proof_idx, TowerSumcheckChallengeMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx - AB::Expr::ONE, sumcheck_round: local.round.into(), challenge: local.prev_challenge.map(Into::into), @@ -348,9 +351,9 @@ where builder, local.proof_idx, TowerSumcheckChallengeMessage { - idx: local.idx.into(), + chip_id: local.chip_id.into(), layer_idx: local.layer_idx.into(), - sumcheck_round: local.round.into() + AB::Expr::ONE, + sumcheck_round: local.round.into(), challenge: local.challenge.map(Into::into), }, local.is_enabled * (AB::Expr::ONE - local.is_last_layer) * is_not_dummy.clone(), @@ -373,6 +376,16 @@ where ); tidx += AB::Expr::from_usize(D_EF); } + for (i, value) in LABEL_INTERNAL_ROUND_VALUES.iter().enumerate() { + self.transcript_bus.observe( + builder, + local.proof_idx, + tidx.clone() + AB::Expr::from_usize(i), + AB::Expr::from_usize(*value), + local.is_enabled * is_not_dummy.clone(), + ); + } + tidx += AB::Expr::from_usize(LABEL_INTERNAL_ROUND); // 1b. Sample challenge `ri` self.transcript_bus.sample_ext( builder, @@ -382,17 +395,7 @@ where local.is_enabled * is_not_dummy.clone(), ); - // 2. XiRandomnessBus - // 2a. Send last challenge - self.xi_randomness_bus.send( - builder, - local.proof_idx, - XiRandomnessMessage { - idx: local.round + AB::Expr::ONE, - xi: local.challenge.map(Into::into), - }, - local.is_enabled * local.is_last_layer * is_not_dummy.clone(), - ); + let _ = &self.xi_randomness_bus; } } diff --git a/ceno_recursion_v2/src/tower/sumcheck/trace.rs b/ceno_recursion_v2/src/tower/sumcheck/trace.rs index 9d84a04a3..289ac2141 100644 --- a/ceno_recursion_v2/src/tower/sumcheck/trace.rs +++ b/ceno_recursion_v2/src/tower/sumcheck/trace.rs @@ -12,8 +12,10 @@ use crate::{tower::tower_transcript_len, tracegen::RowMajorChip}; pub struct TowerSumcheckRecord { pub proof_idx: usize, pub idx: usize, + pub chip_id: usize, pub is_first_air_idx: bool, pub tidx: usize, + pub layer_tidxs: Vec, pub evals: Vec<[EF; 3]>, pub ris: Vec, pub claims: Vec, @@ -43,22 +45,24 @@ impl TowerSumcheckRecord { #[inline] fn derive_tidx(&self, layer_idx: usize, round_in_layer: usize) -> usize { - let rounds_before_layer = Self::layer_start_index(layer_idx); - self.tidx - + tower_transcript_len::ROUND_LEN * (rounds_before_layer + round_in_layer) - + tower_transcript_len::LAYER_GAP_LEN * layer_idx + self.layer_tidxs + .get(layer_idx) + .copied() + .unwrap_or(self.tidx) + + tower_transcript_len::ROUND_LEN * round_in_layer } #[inline] pub fn prev_challenge(layer_idx: usize, round_in_layer: usize, mus: &[EF], ris: &[EF]) -> EF { - if round_in_layer == 0 { - mus[layer_idx] - } else { - let prev_layer = layer_idx - .checked_sub(1) - .expect("round_in_layer > 0 only occurs for non-root layers"); - let offset = Self::layer_start_index(prev_layer) + (round_in_layer - 1); + if layer_idx == 0 { + debug_assert_eq!(round_in_layer, 0); + mus[0] + } else if round_in_layer < layer_idx { + let prev_layer = layer_idx - 1; + let offset = Self::layer_start_index(prev_layer) + round_in_layer; ris[offset] + } else { + mus[layer_idx] } } @@ -145,6 +149,7 @@ impl RowMajorChip for TowerSumcheckTraceGenerator { cols.tidx = F::from_usize(D_EF); cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.layer_idx = F::ONE; cols.is_first_round = F::ONE; cols.is_first_idx = F::from_bool(record.is_first_air_idx); @@ -215,6 +220,7 @@ impl RowMajorChip for TowerSumcheckTraceGenerator { cols.is_enabled = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.chip_id = F::from_usize(record.chip_id); cols.layer_idx = F::from_usize(layer_idx_value); cols.is_last_layer = F::from_bool(is_last_layer); diff --git a/ceno_recursion_v2/src/tower/tower.rs b/ceno_recursion_v2/src/tower/tower.rs index 0d288adf9..5b78818f8 100644 --- a/ceno_recursion_v2/src/tower/tower.rs +++ b/ceno_recursion_v2/src/tower/tower.rs @@ -40,6 +40,7 @@ pub struct TowerReplayResult { pub layers: Vec, } +#[allow(dead_code)] pub fn replay_tower_proof( chip_proof: &ZKVMChipProof, vk: &VerifyingKey, diff --git a/ceno_recursion_v2/src/transcript/mod.rs b/ceno_recursion_v2/src/transcript/mod.rs index 1588acd78..030f74ee2 100644 --- a/ceno_recursion_v2/src/transcript/mod.rs +++ b/ceno_recursion_v2/src/transcript/mod.rs @@ -5,7 +5,7 @@ use openvm_cpu_backend::CpuBackend; use openvm_poseidon2_air::{POSEIDON2_WIDTH, Poseidon2Config, Poseidon2SubChip}; use openvm_stark_backend::{AirRef, StarkProtocolConfig, prover::AirProvingContext}; use openvm_stark_sdk::{ - config::baby_bear_poseidon2::{F, poseidon2_perm}, + config::baby_bear_poseidon2::{D_EF, F, poseidon2_perm}, p3_baby_bear::Poseidon2BabyBear, }; use p3_air::BaseAir; @@ -13,8 +13,12 @@ use p3_field::{PrimeCharacteristicRing, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_symmetric::Permutation; -use crate::system::{ - AirModule, BusInventory, GlobalCtxCpu, Preflight, RecursionProof, RecursionVk, TraceGenModule, +use crate::{ + system::{ + AirModule, BusInventory, GlobalCtxCpu, Preflight, RecursionProof, RecursionVk, + TraceGenModule, + }, + utils::TranscriptLabel, }; use recursion_circuit::transcript::poseidon2::{CHUNK, Poseidon2Air, Poseidon2Cols}; @@ -98,12 +102,20 @@ impl TranscriptModule { is_fork_start: bool, initial_state: [F; POSEIDON2_WIDTH], tidx_offset: usize, + global_export_ranges: Option<&[(usize, usize)]>, poseidon2_perm_inputs: &mut Vec<[F; POSEIDON2_WIDTH]>, ) -> [F; POSEIDON2_WIDTH] { let mut tidx = 0usize; let mut prev_poseidon_state = initial_state; + let fork_prelude_len = TranscriptLabel::Fork.field_len() + 2 * D_EF + 4; + let fork_sample_start = if is_fork_start { + log.len().saturating_sub(D_EF) + } else { + usize::MAX + }; for (i, row) in trace.chunks_exact_mut(transcript_width).enumerate() { + let row_start_tidx = tidx; let cols: &mut ForkedTranscriptCols = row.borrow_mut(); cols.proof_idx = F::from_usize(proof_idx); cols.fork_id = F::from_usize(fork_id); @@ -117,8 +129,16 @@ impl TranscriptModule { let is_sample = log.samples()[tidx]; cols.is_sample = F::from_bool(is_sample); - cols.tidx = F::from_usize(tidx + tidx_offset); + cols.tidx = F::from_usize(row_start_tidx); + cols.global_tidx = F::from_usize(tidx + tidx_offset); + cols.is_fork = F::from_bool(is_fork_start); cols.mask[0] = F::ONE; + cols.global_bus_mask[0] = + F::from_bool(should_export_global(row_start_tidx, global_export_ranges)); + cols.forked_bus_mask[0] = F::from_bool( + is_fork_start + && (row_start_tidx < fork_prelude_len || row_start_tidx >= fork_sample_start), + ); cols.prev_state = prev_poseidon_state; if is_sample { @@ -141,6 +161,12 @@ impl TranscriptModule { } cols.mask[idx] = F::ONE; + let op_tidx = row_start_tidx + idx; + cols.global_bus_mask[idx] = + F::from_bool(should_export_global(op_tidx, global_export_ranges)); + cols.forked_bus_mask[idx] = F::from_bool( + is_fork_start && (op_tidx < fork_prelude_len || op_tidx >= fork_sample_start), + ); if is_sample { debug_assert_eq!(cols.prev_state[CHUNK - 1 - idx], log.values()[tidx]); } else { @@ -217,6 +243,7 @@ impl TranscriptModule { let info = &proof_infos[pidx]; // Fill trunk rows (fork_id = 0, tidx_offset = 0). + let trunk_global_export_ranges = trunk_global_export_ranges(preflight); let trunk_end = offset + info.trunk_rows; let trunk_slice = &mut transcript_trace[offset * transcript_width..trunk_end * transcript_width]; @@ -230,12 +257,14 @@ impl TranscriptModule { false, // is_fork_start [F::ZERO; POSEIDON2_WIDTH], // trunk starts with zero state 0, // tidx_offset: trunk starts at global tidx 0 + Some(trunk_global_export_ranges.as_slice()), &mut poseidon2_perm_inputs, ); offset = trunk_end; - // Fill fork rows with fork-local tidx offsets. + // Fill fork rows with fork-local tidx and global TranscriptBus offsets. for (fi, fork_log) in preflight.fork_transcripts.iter().enumerate() { + let global_export_ranges = fork_global_export_ranges(preflight, fi); let fork_rows = info.fork_rows[fi]; let fork_end = offset + fork_rows; let fork_slice = @@ -251,7 +280,8 @@ impl TranscriptModule { false, // is_proof_start true, // is_fork_start [F::ZERO; POSEIDON2_WIDTH], - 0, + preflight.fork_global_offset(fi), + Some(global_export_ranges.as_slice()), &mut poseidon2_perm_inputs, ); offset = fork_end; @@ -313,6 +343,69 @@ impl TranscriptModule { } } +fn should_export_global(tidx: usize, ranges: Option<&[(usize, usize)]>) -> bool { + match ranges { + None => true, + Some(ranges) => ranges + .iter() + .any(|&(start, end)| start <= tidx && tidx < end), + } +} + +fn fork_global_export_ranges(preflight: &Preflight, fork_idx: usize) -> Vec<(usize, usize)> { + let Some(main_tidx) = preflight + .main + .chips + .iter() + .find(|entry| entry.fork_idx == fork_idx) + .map(|entry| entry.tidx) + else { + return Vec::new(); + }; + + preflight + .gkr + .chips + .iter() + .filter(|entry| entry.fork_idx == fork_idx) + .filter(|entry| entry.num_layers > 0) + .flat_map(|entry| [(entry.tidx, main_tidx), (main_tidx, main_tidx + D_EF)]) + .collect() +} + +fn trunk_global_export_ranges(preflight: &Preflight) -> Vec<(usize, usize)> { + let log_len = preflight.transcript.len(); + if preflight.batch_constraint.lambda_tidx == 0 + && preflight.batch_constraint.tidx_before_univariate == 0 + { + return if log_len == 0 { + Vec::new() + } else { + vec![(0, log_len)] + }; + } + + let lambda_start = preflight.batch_constraint.lambda_tidx.min(log_len); + let lambda_end = (lambda_start + D_EF).min(log_len); + let mu_end = preflight + .batch_constraint + .tidx_before_univariate + .min(log_len); + let mu_start = mu_end.saturating_sub(D_EF); + + let mut ranges = Vec::new(); + push_nonempty_range(&mut ranges, 0, lambda_start); + push_nonempty_range(&mut ranges, lambda_end, mu_start); + push_nonempty_range(&mut ranges, mu_end, log_len); + ranges +} + +fn push_nonempty_range(ranges: &mut Vec<(usize, usize)>, start: usize, end: usize) { + if start < end { + ranges.push((start, end)); + } +} + impl AirModule for TranscriptModule { fn num_airs(&self) -> usize { 2 diff --git a/ceno_recursion_v2/src/transcript/transcript_air.rs b/ceno_recursion_v2/src/transcript/transcript_air.rs index ca12a2c28..b55b802ae 100644 --- a/ceno_recursion_v2/src/transcript/transcript_air.rs +++ b/ceno_recursion_v2/src/transcript/transcript_air.rs @@ -20,7 +20,6 @@ use openvm_circuit_primitives::{ use openvm_stark_backend::{ BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, }; -use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{Field, PrimeCharacteristicRing}; use p3_matrix::Matrix; @@ -46,11 +45,18 @@ pub struct ForkedTranscriptCols { pub proof_idx: T, pub is_proof_start: T, + /// Fork-local transcript index. For trunk rows this is equal to `global_tidx`. pub tidx: T, + /// Absolute transcript index in the flattened proof transcript namespace. + pub global_tidx: T, /// Indicator for sample/observe. pub is_sample: T, /// 0/1 indicators for positions being absorbed/squeezed. pub mask: [T; CHUNK], + /// 0/1 indicators for positions exported to the global TranscriptBus. + pub global_bus_mask: [T; CHUNK], + /// 0/1 indicators for fork-local transcript positions exported to `ForkedTranscriptBus`. + pub forked_bus_mask: [T; CHUNK], /// The poseidon2 state. pub prev_state: [T; POSEIDON2_WIDTH], @@ -59,18 +65,22 @@ pub struct ForkedTranscriptCols { // --- fork extensions --- /// 1 on the first row of a forked transcript chain. pub is_fork_start: T, + /// 1 for rows belonging to a forked transcript chain. + pub is_fork: T, /// Fork identifier (0-based across forked chip transcripts). pub fork_id: T, } impl ForkedTranscriptCols { pub const fn width() -> usize { - // proof_idx, is_proof_start, tidx, is_sample = 4 + // proof_idx, is_proof_start, tidx, global_tidx, is_sample = 5 // mask = CHUNK + // global_bus_mask = CHUNK + // forked_bus_mask = CHUNK // prev_state = POSEIDON2_WIDTH // post_state = POSEIDON2_WIDTH - // is_fork_start, fork_id = 2 - 4 + CHUNK + 2 * POSEIDON2_WIDTH + 2 + // is_fork_start, is_fork, fork_id = 3 + 5 + 3 * CHUNK + 2 * POSEIDON2_WIDTH + 3 } } @@ -127,6 +137,7 @@ impl Air for ForkedTranscriptAir { // is_proof_start and is_fork_start are mutually exclusive booleans builder.assert_bool(local.is_proof_start); builder.assert_bool(local.is_fork_start); + builder.assert_bool(local.is_fork); // A row is a "chain start" if either is_proof_start or is_fork_start let is_chain_start: AB::Expr = local.is_proof_start.into() + local.is_fork_start.into(); // At most one of these can be 1 @@ -157,15 +168,22 @@ impl Air for ForkedTranscriptAir { // When is_proof_start: tidx = 0, sponge state = 0 (trunk start) builder.when(local.is_proof_start).assert_zero(local.tidx); + builder + .when(local.is_proof_start) + .assert_zero(local.global_tidx); builder.when(local.is_proof_start).assert_one(is_valid); builder .when(local.is_proof_start) .assert_zero(local.fork_id); + builder + .when(local.is_proof_start) + .assert_zero(local.is_fork); builder.assert_bool(local.is_sample); - // When is_fork_start: fork chain begins (tidx is NOT zero; it's the - // fork's global tidx offset). Only constrain validity. + // When is_fork_start: fork chain begins at fork-local tidx 0. builder.when(local.is_fork_start).assert_one(is_valid); + builder.when(local.is_fork_start).assert_one(local.is_fork); + builder.when(local.is_fork_start).assert_zero(local.tidx); // Initial state for proof start (trunk): all-zero sponge for i in 0..CHUNK { @@ -187,6 +205,18 @@ impl Air for ForkedTranscriptAir { let mut count = AB::Expr::ZERO; for i in 0..CHUNK { builder.assert_bool(local.mask[i]); + builder.assert_bool(local.global_bus_mask[i]); + builder.assert_bool(local.forked_bus_mask[i]); + builder + .when(local.global_bus_mask[i]) + .assert_one(local.mask[i]); + builder + .when(local.forked_bus_mask[i]) + .assert_one(local.is_fork); + builder + .when(local.forked_bus_mask[i]) + .assert_one(local.mask[i]); + builder.assert_zero(local.global_bus_mask[i] * local.forked_bus_mask[i]); count += local.mask[i].into(); let skip = local.mask[i] - AB::Expr::ONE; @@ -212,6 +242,9 @@ impl Air for ForkedTranscriptAir { builder .when(local_next_same_chain.clone()) .assert_eq(next.tidx, local.tidx + count.clone()); + builder + .when(local_next_same_chain.clone()) + .assert_eq(next.global_tidx, local.global_tidx + count.clone()); // If local.is_sample == next.is_sample within the same chain, // there must be exactly CHUNK operations. @@ -224,83 +257,66 @@ impl Air for ForkedTranscriptAir { builder .when(local_next_same_chain.clone()) .assert_eq(local.fork_id, next.fork_id); + builder + .when(local_next_same_chain.clone()) + .assert_eq(local.is_fork, next.is_fork); /////////////////////////////////////////////////////////////////////// // Transcript bus interactions (send) /////////////////////////////////////////////////////////////////////// for i in 0..CHUNK { let observe_message = TranscriptBusMessage { - tidx: local.tidx + AB::Expr::from_usize(i), + tidx: local.global_tidx + AB::Expr::from_usize(i), value: local.prev_state[i].into(), is_sample: AB::Expr::ZERO, }; let sample_message = TranscriptBusMessage { - tidx: local.tidx + AB::Expr::from_usize(i), + tidx: local.global_tidx + AB::Expr::from_usize(i), value: local.prev_state[CHUNK - 1 - i].into(), is_sample: AB::Expr::ONE, }; + let transcript_mult = local.global_bus_mask[i]; self.transcript_bus.send( builder, local.proof_idx, observe_message, - local.mask[i] * (AB::Expr::ONE - local.is_sample), + transcript_mult.clone() * (AB::Expr::ONE - local.is_sample), ); self.transcript_bus.send( builder, local.proof_idx, sample_message, - local.mask[i] * local.is_sample, + transcript_mult * local.is_sample, ); } /////////////////////////////////////////////////////////////////////// - // Forked transcript bus interactions (send fork state) + // Forked transcript bus interactions (send fork-local operations) /////////////////////////////////////////////////////////////////////// - // On is_fork_start rows, send fork-local transcript words with fork_id. - for i in 0..D_EF { - self.forked_transcript_bus.send( - builder, - local.proof_idx, - ForkedTranscriptBusMessage { - fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(i), - value: local.prev_state[i].into(), - is_sample: AB::Expr::ZERO, - }, - local.is_fork_start, - ); - self.forked_transcript_bus.send( - builder, - local.proof_idx, - ForkedTranscriptBusMessage { - fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(D_EF + i), - value: local.prev_state[i].into(), - is_sample: AB::Expr::ZERO, - }, - local.is_fork_start, - ); + for i in 0..CHUNK { + let observe_message = ForkedTranscriptBusMessage { + fork_id: local.fork_id.into(), + tidx: local.tidx + AB::Expr::from_usize(i), + value: local.prev_state[i].into(), + is_sample: AB::Expr::ZERO, + }; + let sample_message = ForkedTranscriptBusMessage { + fork_id: local.fork_id.into(), + tidx: local.tidx + AB::Expr::from_usize(i), + value: local.prev_state[CHUNK - 1 - i].into(), + is_sample: AB::Expr::ONE, + }; self.forked_transcript_bus.send( builder, local.proof_idx, - ForkedTranscriptBusMessage { - fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(2 * D_EF + i), - value: local.prev_state[i].into(), - is_sample: AB::Expr::ONE, - }, - local.is_fork_start, + observe_message, + local.forked_bus_mask[i] * (AB::Expr::ONE - local.is_sample), ); self.forked_transcript_bus.send( builder, local.proof_idx, - ForkedTranscriptBusMessage { - fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(3 * D_EF + i), - value: local.prev_state[i].into(), - is_sample: AB::Expr::ONE, - }, - local.is_fork_start, + sample_message, + local.forked_bus_mask[i] * local.is_sample, ); }