Skip to content

Commit 7e5ecee

Browse files
darth-cykunxian-xiahero78119
authored
Use NativeSumcheck for calculating initial_sum (#1249)
Optimization for Calculating Initial Sum in tower verify. - Remove dot product logic for calculating `initial_sum`. Use `NativeSumcheck` instead. - Flatten out evaluation arrays. - Remove arithmetic utilities no longer required. Optimization for `NativeSumcheck` - Use the additional mode on `NativeSumcheck` chip that allows passing in hint space IDs for evaluation inputs instead of loading concrete witness arrays. This significantly reduces cycles involved in loading witnesses. --------- Co-authored-by: kunxian xia <xiakunxian130@gmail.com> Co-authored-by: Ming <hero78119@gmail.com>
1 parent 8a63ff4 commit 7e5ecee

10 files changed

Lines changed: 411 additions & 380 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ docs/book
1515
# ceno serialized files
1616
*.bin
1717
*.json
18+
*.srs

Cargo.lock

Lines changed: 47 additions & 47 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,16 @@ ceno_gpu = { git = "https://github.com/scroll-tech/ceno-gpu-mock.git", package =
9292
cudarc = { version = "0.17.3", features = ["driver", "cuda-version-from-build-system"] }
9393

9494
# ceno-recursion dependencies
95-
openvm = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
96-
openvm-circuit = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
97-
openvm-continuations = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
98-
openvm-instructions = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
99-
openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
100-
openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
101-
openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
102-
openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
103-
openvm-rv32im-circuit = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
104-
openvm-sdk = { git = "https://github.com/scroll-tech/openvm.git", rev = "ef22e8e", default-features = false }
95+
openvm = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
96+
openvm-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
97+
openvm-continuations = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
98+
openvm-instructions = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
99+
openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
100+
openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
101+
openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
102+
openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
103+
openvm-rv32im-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
104+
openvm-sdk = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/v1.4.1-scroll-ext", default-features = false }
105105

106106
openvm-cuda-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1", default-features = false }
107107
openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1", default-features = false }

ceno_recursion/src/aggregation/internal.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ impl<C: Config> NonLeafVerifierVariables<C> {
9696
let proof = builder.get(proofs, i);
9797
assert_required_air_for_agg_vm_present(builder, &proof);
9898
let proof_vm_pvs = self.verify_internal_or_leaf_verifier_proof(builder, &proof);
99-
10099
assert_single_segment_vm_exit_successfully(builder, &proof);
101-
102100
builder.if_eq(i, RVar::zero()).then_or_else(
103101
|builder| {
104102
builder.assign(&pvs.app_commit, proof_vm_pvs.vm_verifier_pvs.app_commit);

ceno_recursion/src/aggregation/mod.rs

Lines changed: 89 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,10 @@ pub const INTERNAL_LOG_BLOWUP: usize = 2;
8686
pub const ROOT_LOG_BLOWUP: usize = 3;
8787
pub const SBOX_SIZE: usize = 7;
8888
const VM_MAX_TRACE_HEIGHTS: &[u32] = &[
89-
4194304, 4, 128, 2097152, 8388608, 4194304, 262144, 8388608, 16777216, 2097152, 16777216,
90-
2097152, 8388608, 262144, 2097152, 1048576, 4194304, 1048576, 262144,
89+
4194304, 4, 128, 2097152, 8388608, 4194304, 262144, 8388608, 16777216, 16777216, 2097152,
90+
16777216, 2097152, 8388608, 262144, 2097152, 1048576, 4194304, 1048576, 262144,
9191
];
92+
9293
pub struct CenoAggregationProver {
9394
pub base_vk: ZKVMVerifyingKey<E, Basefold<E, BasefoldRSParams>>,
9495
pub leaf_prover: VmInstance<BabyBearPoseidon2Engine, NativeBuilder>,
@@ -290,7 +291,7 @@ impl CenoAggregationProver {
290291

291292
// _debug: export
292293
// let file =
293-
// File::create(format!("leaf_proof_{:?}.bin", proof_idx)).expect("Create export proof file");
294+
// File::create(format!("leaf_proof_{:?}.bin", proof_idx)).expect("Create export proof file");
294295
// bincode::serialize_into(file, &leaf_proof).expect("failed to serialize leaf proof");
295296

296297
println!(
@@ -304,14 +305,28 @@ impl CenoAggregationProver {
304305
})
305306
.collect::<Vec<_>>();
306307

307-
// Aggregate tree to root proof
308+
// Aggregate leaf proofs into a single internal proof via binary tree
309+
let root_inner = self.aggregate_internal_proofs(leaf_proofs);
310+
311+
// Export e2e stark proof (used in verify_e2e_stark_proof)
312+
VmStarkProof {
313+
inner: root_inner,
314+
user_public_values,
315+
}
316+
}
317+
318+
/// Aggregate leaf (or internal) proofs into a single root internal proof
319+
/// via a binary tree of internal proving rounds.
320+
pub fn aggregate_internal_proofs(&mut self, leaf_proofs: Vec<Proof<SC>>) -> Proof<SC> {
321+
let start = Instant::now();
322+
308323
let mut internal_node_idx = -1;
309324
let mut internal_node_height = 0;
310325
let mut proofs = leaf_proofs;
311326

312327
println!(
313328
"Aggregation - Start internal aggregation at: {:?}",
314-
aggregation_start_timestamp.elapsed()
329+
start.elapsed()
315330
);
316331
// We will always generate at least one internal proof, even if there is only one leaf
317332
// proof, in order to shrink the proof size
@@ -321,7 +336,6 @@ impl CenoAggregationProver {
321336
&proofs,
322337
DEFAULT_NUM_CHILDREN_INTERNAL,
323338
);
324-
325339
let layer_proofs: Vec<Proof<_>> = internal_inputs
326340
.into_iter()
327341
.map(|input| {
@@ -337,7 +351,7 @@ impl CenoAggregationProver {
337351
"Aggregation - Completed internal node (idx: {:?}) at height {:?}: {:?}",
338352
internal_node_idx,
339353
internal_node_height,
340-
aggregation_start_timestamp.elapsed()
354+
start.elapsed()
341355
);
342356

343357
// _debug: export
@@ -356,17 +370,13 @@ impl CenoAggregationProver {
356370
}
357371
println!(
358372
"Aggregation - Completed internal aggregation at: {:?}",
359-
aggregation_start_timestamp.elapsed()
373+
start.elapsed()
360374
);
361375
println!("Aggregation - Final height: {:?}", internal_node_height);
362376

363377
// TODO: generate root proof from last internal proof
364378

365-
// Export e2e stark proof (used in verify_e2e_stark_proof)
366-
VmStarkProof {
367-
inner: proofs.pop().unwrap(),
368-
user_public_values,
369-
}
379+
proofs.pop().unwrap()
370380
}
371381
}
372382

@@ -415,6 +425,25 @@ impl CenoLeafVmVerifierConfig {
415425
builder.assign(&stark_pvs.connector.initial_pc, init_pc);
416426
builder.assign(&stark_pvs.connector.final_pc, end_pc);
417427
builder.assign(&stark_pvs.connector.exit_code, exit_code);
428+
// Internal aggregation asserts connector chaining on this field.
429+
builder
430+
.if_eq(ceno_leaf_input.is_last, Usize::from(1))
431+
.then_or_else(
432+
|builder| {
433+
builder.assign(&stark_pvs.connector.is_terminate, F::ONE);
434+
},
435+
|builder| {
436+
builder.assign(&stark_pvs.connector.is_terminate, F::ZERO);
437+
},
438+
);
439+
440+
// Keep remaining committed PVs deterministic until real memory/public-values
441+
// commitments are wired through this custom leaf program.
442+
for i in 0..DIGEST_SIZE {
443+
builder.assign(&stark_pvs.memory.initial_root[i], F::ZERO);
444+
builder.assign(&stark_pvs.memory.final_root[i], F::ZERO);
445+
builder.assign(&stark_pvs.public_values_commit[i], F::ZERO);
446+
}
418447

419448
// TODO: assign shard_ec_sum to stark_pvs.shard_ec_sum
420449

@@ -693,6 +722,7 @@ pub fn verify_proofs(
693722

694723
let fri_params = standard_fri_params_with_100_bits_conjectured_security(1);
695724
let vb = NativeBuilder::default();
725+
696726
air_test_impl::<BabyBearPoseidon2Engine, _>(
697727
fri_params,
698728
vb,
@@ -703,14 +733,21 @@ pub fn verify_proofs(
703733
true,
704734
)
705735
.unwrap();
736+
737+
// _debug
738+
// let engine = BabyBearPoseidon2Engine::new(fri_params);
739+
// let (mut vm, pk) = VirtualMachine::new_with_keygen(engine, vb, config).expect("create vm");
740+
// let vk = pk.get_vk();
741+
// vm.verify(&vk, &proofs)
742+
// .expect("segment proofs should verify");
706743
}
707744
}
708745

709746
#[cfg(test)]
710747
mod tests {
711748
use super::verify_e2e_stark_proof;
712749
use crate::{
713-
aggregation::{CenoAggregationProver, verify_proofs},
750+
aggregation::{CenoAggregationProver, SC, verify_proofs},
714751
zkvm_verifier::binding::E,
715752
};
716753
use ceno_zkvm::{
@@ -719,6 +756,7 @@ mod tests {
719756
structs::ZKVMVerifyingKey,
720757
};
721758
use mpcs::{Basefold, BasefoldRSParams};
759+
use openvm_stark_backend::proof::Proof;
722760
use openvm_stark_sdk::{config::setup_tracing_with_log_level, p3_bn254_fr::Bn254Fr};
723761
use p3::field::FieldAlgebra;
724762
use std::fs::File;
@@ -785,6 +823,30 @@ mod tests {
785823
verify(zkvm_proofs.clone(), &verifier).expect("Verification failed");
786824
}
787825

826+
pub fn internal_aggregation_inner_thread() {
827+
setup_tracing_with_log_level(tracing::Level::WARN);
828+
829+
let vk_path = "./src/imported/vk.bin";
830+
let vk: ZKVMVerifyingKey<E, Basefold<E, BasefoldRSParams>> =
831+
bincode::deserialize_from(File::open(vk_path).expect("Failed to open vk file"))
832+
.expect("Failed to deserialize vk file");
833+
834+
let mut agg_prover = CenoAggregationProver::from_base_vk(vk);
835+
836+
// Load exported leaf proofs
837+
let leaf_proof_0: Proof<SC> = bincode::deserialize_from(
838+
File::open("./leaf_proof_0.bin").expect("Failed to open leaf_proof_0.bin"),
839+
)
840+
.expect("Failed to deserialize leaf_proof_0");
841+
let leaf_proof_1: Proof<SC> = bincode::deserialize_from(
842+
File::open("./leaf_proof_1.bin").expect("Failed to open leaf_proof_1.bin"),
843+
)
844+
.expect("Failed to deserialize leaf_proof_1");
845+
846+
let leaf_proofs = vec![leaf_proof_0, leaf_proof_1];
847+
let _root_proof = agg_prover.aggregate_internal_proofs(leaf_proofs);
848+
}
849+
788850
#[test]
789851
#[ignore = "need to generate proof first"]
790852
pub fn test_aggregation() {
@@ -798,6 +860,19 @@ mod tests {
798860
handler.join().expect("Thread panicked");
799861
}
800862

863+
#[test]
864+
#[ignore = "need to generate proof first"]
865+
pub fn test_internal_aggregation() {
866+
let stack_size = 256 * 1024 * 1024;
867+
868+
let handler = std::thread::Builder::new()
869+
.stack_size(stack_size)
870+
.spawn(internal_aggregation_inner_thread)
871+
.expect("Failed to spawn thread");
872+
873+
handler.join().expect("Thread panicked");
874+
}
875+
801876
#[test]
802877
#[ignore = "need to generate proof first"]
803878
pub fn test_single() {

ceno_recursion/src/arithmetics/mod.rs

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ use multilinear_extensions::{Expression, Fixed, Instance};
99
use openvm_native_circuit::EXT_DEG;
1010
use openvm_native_compiler::prelude::*;
1111
use openvm_native_compiler_derive::iter_zip;
12-
use openvm_native_recursion::challenger::{FeltChallenger, duplex::DuplexChallengerVariable};
12+
use openvm_native_recursion::{
13+
challenger::{FeltChallenger, duplex::DuplexChallengerVariable},
14+
vars::HintSlice,
15+
};
1316
use openvm_stark_backend::p3_field::{FieldAlgebra, FieldExtensionAlgebra};
1417

1518
type E = BabyBearExt4;
@@ -64,8 +67,41 @@ pub fn challenger_multi_observe<C: Config>(
6467
challenger: &mut DuplexChallengerVariable<C>,
6568
arr: &Array<C, Felt<C::F>>,
6669
) {
67-
let next_input_ptr =
68-
builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, arr);
70+
let next_input_ptr = builder.poseidon2_multi_observe(
71+
&challenger.sponge_state,
72+
challenger.input_ptr,
73+
arr,
74+
arr.len(),
75+
None,
76+
);
77+
builder.assign(
78+
&challenger.input_ptr,
79+
challenger.io_empty_ptr + next_input_ptr.clone(),
80+
);
81+
builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else(
82+
|builder| {
83+
builder.assign(&challenger.output_ptr, challenger.io_empty_ptr);
84+
},
85+
|builder| {
86+
builder.assign(&challenger.output_ptr, challenger.io_full_ptr);
87+
},
88+
);
89+
}
90+
91+
pub fn challenger_hint_observe<C: Config>(
92+
builder: &mut Builder<C>,
93+
challenger: &mut DuplexChallengerVariable<C>,
94+
hint_slice: &HintSlice<C>,
95+
) {
96+
let dummy_arr: Array<C, Felt<C::F>> = builder.dyn_array(0);
97+
let felt_len: Usize<C::N> = builder.eval(hint_slice.length.clone() * Usize::from(C::EF::D));
98+
let next_input_ptr = builder.poseidon2_multi_observe(
99+
&challenger.sponge_state,
100+
challenger.input_ptr,
101+
&dummy_arr,
102+
felt_len,
103+
Some(hint_slice.id.get_var()),
104+
);
69105
builder.assign(
70106
&challenger.input_ptr,
71107
challenger.io_empty_ptr + next_input_ptr.clone(),
@@ -98,18 +134,6 @@ pub fn is_smaller_than<C: Config>(
98134
RVar::from(v)
99135
}
100136

101-
pub fn evaluate_at_point_degree_1<C: Config>(
102-
builder: &mut Builder<C>,
103-
evals: &Array<C, Ext<C::F, C::EF>>,
104-
point: &Array<C, Ext<C::F, C::EF>>,
105-
) -> Ext<C::F, C::EF> {
106-
let left = builder.get(evals, 0);
107-
let right = builder.get(evals, 1);
108-
let r = builder.get(point, 0);
109-
110-
builder.eval(r * (right - left) + left)
111-
}
112-
113137
pub fn _fixed_dot_product<C: Config>(
114138
builder: &mut Builder<C>,
115139
a: &[Ext<C::F, C::EF>],
@@ -329,24 +353,6 @@ pub fn eq_eval_with_index<C: Config>(
329353
acc
330354
}
331355

332-
// Multiply all elements in a nested Array
333-
pub fn nested_product<C: Config>(
334-
builder: &mut Builder<C>,
335-
arr: &Array<C, Array<C, Ext<C::F, C::EF>>>,
336-
) -> Ext<C::F, C::EF> {
337-
let acc = builder.constant(C::EF::ONE);
338-
iter_zip!(builder, arr).for_each(|ptr_vec, builder| {
339-
let inner_arr = builder.iter_ptr_get(arr, ptr_vec[0]);
340-
341-
iter_zip!(builder, inner_arr).for_each(|ptr_vec, builder| {
342-
let el = builder.iter_ptr_get(&inner_arr, ptr_vec[0]);
343-
builder.assign(&acc, acc * el);
344-
});
345-
});
346-
347-
acc
348-
}
349-
350356
// Multiply all elements in an Array
351357
pub fn arr_product<C: Config>(
352358
builder: &mut Builder<C>,

0 commit comments

Comments
 (0)