Skip to content

Commit 07fc421

Browse files
feat: multiple instances of a circuit in single shard (#1126)
# Summary We want to prove multiple instances of a circuit in a single shard. For example, in the case of `ShardRam` circuit, the 0-th shard often has ~2M records which is too large to fit in GPU memory. - [ ] #1130 --------- Co-authored-by: sm.wu <[email protected]>
1 parent b2233fd commit 07fc421

File tree

12 files changed

+349
-298
lines changed

12 files changed

+349
-298
lines changed

.github/workflows/integration.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ jobs:
141141
RUSTFLAGS: "-C opt-level=3"
142142
run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_curve_syscalls
143143

144+
- name: Run fibonacci (release) in 3 shards with CENO_CROSS_SHARD_LIMIT
145+
env:
146+
RUST_LOG: debug
147+
RUSTFLAGS: "-C opt-level=3"
148+
CENO_CROSS_SHARD_LIMIT: 32
149+
run: cargo run --release --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=20000 --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci
150+
144151
- name: Install cargo make
145152
run: |
146153
cargo make --version || cargo install cargo-make

ceno_zkvm/src/e2e.rs

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::{
77
hal::ProverDevice,
88
mock_prover::{LkMultiplicityKey, MockProver},
99
prover::ZKVMProver,
10+
septic_curve::SepticPoint,
1011
verifier::ZKVMVerifier,
1112
},
1213
state::GlobalState,
@@ -44,6 +45,7 @@ use witness::next_pow2_instance_padding;
4445

4546
pub const DEFAULT_MIN_CYCLE_PER_SHARDS: Cycle = 1 << 24;
4647
pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 27;
48+
pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20;
4749

4850
/// The polynomial commitment scheme kind
4951
#[derive(
@@ -175,11 +177,16 @@ pub struct ShardContext<'a> {
175177
Either<Vec<BTreeMap<WordAddr, RAMRecord>>, &'a mut BTreeMap<WordAddr, RAMRecord>>,
176178
pub cur_shard_cycle_range: std::ops::Range<usize>,
177179
pub expected_inst_per_shard: usize,
180+
pub max_num_cross_shard_accesses: usize,
178181
}
179182

180183
impl<'a> Default for ShardContext<'a> {
181184
fn default() -> Self {
182185
let max_threads = max_usable_threads();
186+
let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT")
187+
.map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT))
188+
.unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT);
189+
183190
Self {
184191
shard_id: 0,
185192
num_shards: 1,
@@ -202,6 +209,7 @@ impl<'a> Default for ShardContext<'a> {
202209
),
203210
cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX,
204211
expected_inst_per_shard: usize::MAX,
212+
max_num_cross_shard_accesses,
205213
}
206214
}
207215
}
@@ -231,6 +239,10 @@ impl<'a> ShardContext<'a> {
231239
let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize;
232240
let max_threads = max_usable_threads();
233241

242+
let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT")
243+
.map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT))
244+
.unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT);
245+
234246
// strategies
235247
// 0. set cur_num_shards = num_provers
236248
// 1. split instructions evenly by cur_num_shards
@@ -323,6 +335,7 @@ impl<'a> ShardContext<'a> {
323335
),
324336
cur_shard_cycle_range,
325337
expected_inst_per_shard,
338+
max_num_cross_shard_accesses,
326339
}
327340
})
328341
.collect_vec()
@@ -355,6 +368,7 @@ impl<'a> ShardContext<'a> {
355368
write_records_tbs: Either::Right(write),
356369
cur_shard_cycle_range: self.cur_shard_cycle_range.clone(),
357370
expected_inst_per_shard: self.expected_inst_per_shard,
371+
max_num_cross_shard_accesses: self.max_num_cross_shard_accesses,
358372
},
359373
)
360374
.collect_vec(),
@@ -1125,17 +1139,26 @@ pub fn generate_witness<'a, E: ExtensionField>(
11251139
pi.end_pc = current_shard_end_pc;
11261140
pi.end_cycle = current_shard_end_cycle;
11271141
// set shard ram bus expected output to pi
1128-
let shard_ram_witness = zkvm_witness.get_table_witness(&ShardRamCircuit::<E>::name());
1129-
if let Some(shard_ram_witness) = shard_ram_witness
1130-
&& shard_ram_witness[0].num_instances() > 0
1131-
{
1132-
for (f, v) in ShardRamCircuit::<E>::extract_ec_sum(
1133-
&system_config.mmu_config.ram_bus_circuit,
1134-
&shard_ram_witness[0],
1135-
)
1136-
.into_iter()
1137-
.zip_eq(pi.shard_rw_sum.as_mut_slice())
1138-
{
1142+
let shard_ram_witnesses = zkvm_witness.get_witness(&ShardRamCircuit::<E>::name());
1143+
1144+
if let Some(shard_ram_witnesses) = shard_ram_witnesses {
1145+
let shard_ram_ec_sum: SepticPoint<E::BaseField> = shard_ram_witnesses
1146+
.iter()
1147+
.filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0)
1148+
.map(|shard_ram_witness| {
1149+
ShardRamCircuit::<E>::extract_ec_sum(
1150+
&system_config.mmu_config.ram_bus_circuit,
1151+
&shard_ram_witness.witness_rmms[0],
1152+
)
1153+
})
1154+
.sum();
1155+
1156+
let xy = shard_ram_ec_sum
1157+
.x
1158+
.0
1159+
.iter()
1160+
.chain(shard_ram_ec_sum.y.0.iter());
1161+
for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) {
11391162
*v = f.to_canonical_u64() as u32;
11401163
}
11411164
}

ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ impl<E: ExtensionField> MmuConfig<'_, E> {
196196
&self.local_final_circuit,
197197
&(shard_ctx, all_records.as_slice()),
198198
)?;
199-
witness.assign_global_chip_circuit(
199+
witness.assign_shared_circuit(
200200
cs,
201201
&(shard_ctx, all_records.as_slice()),
202202
&self.ram_bus_circuit,

ceno_zkvm/src/keygen.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
3939
fixed_traces.insert(circuit_index, fixed_trace_rmm);
4040
}
4141

42+
vm_pk
43+
.circuit_name_to_index
44+
.insert(c_name.clone(), circuit_index);
4245
let circuit_pk = cs.key_gen();
4346
assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none());
4447
}

ceno_zkvm/src/scheme.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
88
use std::{
99
collections::{BTreeMap, HashMap},
1010
fmt::{self, Debug},
11+
iter,
1112
ops::Div,
1213
rc::Rc,
1314
};
@@ -156,7 +157,8 @@ pub struct ZKVMProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
156157
pub raw_pi: Vec<Vec<E::BaseField>>,
157158
// the evaluation of raw_pi.
158159
pub pi_evals: Vec<E>,
159-
pub chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
160+
// each circuit may have multiple proof instances
161+
pub chip_proofs: BTreeMap<usize, Vec<ZKVMChipProof<E>>>,
160162
pub witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
161163
pub opening_proof: PCS::Proof,
162164
}
@@ -165,7 +167,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
165167
pub fn new(
166168
raw_pi: Vec<Vec<E::BaseField>>,
167169
pi_evals: Vec<E>,
168-
chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
170+
chip_proofs: BTreeMap<usize, Vec<ZKVMChipProof<E>>>,
169171
witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
170172
opening_proof: PCS::Proof,
171173
) -> Self {
@@ -211,7 +213,13 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
211213
let halt_instance_count = self
212214
.chip_proofs
213215
.get(&halt_circuit_index)
214-
.map_or(0, |proof| proof.num_instances.iter().sum());
216+
.map_or(0, |proofs| {
217+
proofs
218+
.iter()
219+
.flat_map(|proof| &proof.num_instances)
220+
.copied()
221+
.sum()
222+
});
215223
if halt_instance_count > 0 {
216224
assert_eq!(
217225
halt_instance_count, 1,
@@ -240,6 +248,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
240248
let tower_proof = self
241249
.chip_proofs
242250
.iter()
251+
.flat_map(|(circuit_index, proofs)| {
252+
iter::repeat_n(circuit_index, proofs.len()).zip(proofs)
253+
})
243254
.map(|(circuit_index, proof)| {
244255
let size = bincode::serialized_size(&proof.tower_proof);
245256
size.inspect(|size| {
@@ -254,6 +265,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
254265
let main_sumcheck = self
255266
.chip_proofs
256267
.iter()
268+
.flat_map(|(circuit_index, proofs)| {
269+
iter::repeat_n(circuit_index, proofs.len()).zip(proofs)
270+
})
257271
.map(|(circuit_index, proof)| {
258272
let size = bincode::serialized_size(&proof.main_sumcheck_proofs);
259273
size.inspect(|size| {

ceno_zkvm/src/scheme/mock_prover.rs

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use p3::field::{Field, FieldAlgebra};
3434
use rand::thread_rng;
3535
use std::{
3636
cmp::max,
37-
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
37+
collections::{BTreeSet, HashMap, HashSet},
3838
fmt::Debug,
3939
fs::File,
4040
hash::Hash,
@@ -1004,21 +1004,13 @@ Hints:
10041004
let mut fixed_mles = HashMap::new();
10051005
let mut num_instances = HashMap::new();
10061006

1007-
let circuit_index_fixed_num_instances: BTreeMap<String, usize> = fixed_trace
1008-
.circuit_fixed_traces
1009-
.iter()
1010-
.map(|(circuit_name, rmm)| {
1011-
(
1012-
circuit_name.clone(),
1013-
rmm.as_ref().map(|rmm| rmm.num_instances()).unwrap_or(0),
1014-
)
1015-
})
1016-
.collect();
10171007
let mut lkm_tables = LkMultiplicityRaw::<E>::default();
10181008
let mut lkm_opcodes = LkMultiplicityRaw::<E>::default();
10191009

10201010
// Process all circuits.
1021-
for (circuit_name, composed_cs) in &cs.circuit_css {
1011+
for (circuit_name, chip_inputs) in &witnesses.witnesses {
1012+
let composed_cs = cs.circuit_css.get(circuit_name).unwrap();
1013+
// for (circuit_name, composed_cs) in &cs.circuit_css {
10221014
let ComposedConstrainSystem {
10231015
zkvm_v1_css: cs, ..
10241016
} = &composed_cs;
@@ -1037,30 +1029,21 @@ Hints:
10371029
continue;
10381030
}
10391031

1040-
let [witness, structural_witness] = witnesses
1041-
.get_opcode_witness(circuit_name)
1042-
.or_else(|| witnesses.get_table_witness(circuit_name))
1043-
.unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name));
1044-
let num_rows = if witness.num_instances() > 0 {
1045-
witness.num_instances()
1046-
} else if structural_witness.num_instances() > 0 {
1047-
structural_witness.num_instances()
1048-
} else if composed_cs.is_static_circuit() {
1049-
circuit_index_fixed_num_instances
1050-
.get(circuit_name)
1051-
.copied()
1052-
.unwrap_or(0)
1053-
} else {
1054-
0
1055-
};
1032+
assert!(chip_inputs.len() <= 1, "TODO support > 1 chip_inputs");
1033+
let chip_input = chip_inputs.first().filter(|ci| ci.num_instances() > 0);
10561034

1057-
if num_rows == 0 {
1035+
if chip_input.is_none() {
10581036
wit_mles.insert(circuit_name.clone(), vec![]);
10591037
structural_wit_mles.insert(circuit_name.clone(), vec![]);
10601038
fixed_mles.insert(circuit_name.clone(), vec![]);
10611039
num_instances.insert(circuit_name.clone(), 0);
10621040
continue;
10631041
}
1042+
1043+
let chip_input = chip_input.unwrap();
1044+
let num_rows = chip_input.num_instances();
1045+
1046+
let [witness, structural_witness] = &chip_input.witness_rmms;
10641047
let mut witness = witness
10651048
.to_mles()
10661049
.into_iter()

0 commit comments

Comments
 (0)