@@ -7,6 +7,7 @@ use crate::{
77 hal:: ProverDevice ,
88 mock_prover:: { LkMultiplicityKey , MockProver } ,
99 prover:: ZKVMProver ,
10+ septic_curve:: SepticPoint ,
1011 verifier:: ZKVMVerifier ,
1112 } ,
1213 state:: GlobalState ,
@@ -44,6 +45,7 @@ use witness::next_pow2_instance_padding;
4445
4546pub const DEFAULT_MIN_CYCLE_PER_SHARDS : Cycle = 1 << 24 ;
4647pub const DEFAULT_MAX_CYCLE_PER_SHARDS : Cycle = 1 << 27 ;
48+ pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT : usize = 1 << 20 ;
4749
4850/// The polynomial commitment scheme kind
4951#[ derive(
@@ -175,11 +177,16 @@ pub struct ShardContext<'a> {
175177 Either < Vec < BTreeMap < WordAddr , RAMRecord > > , & ' a mut BTreeMap < WordAddr , RAMRecord > > ,
176178 pub cur_shard_cycle_range : std:: ops:: Range < usize > ,
177179 pub expected_inst_per_shard : usize ,
180+ pub max_num_cross_shard_accesses : usize ,
178181}
179182
180183impl < ' a > Default for ShardContext < ' a > {
181184 fn default ( ) -> Self {
182185 let max_threads = max_usable_threads ( ) ;
186+ let max_num_cross_shard_accesses = std:: env:: var ( "CENO_CROSS_SHARD_LIMIT" )
187+ . map ( |v| v. parse ( ) . unwrap_or ( DEFAULT_CROSS_SHARD_ACCESS_LIMIT ) )
188+ . unwrap_or ( DEFAULT_CROSS_SHARD_ACCESS_LIMIT ) ;
189+
183190 Self {
184191 shard_id : 0 ,
185192 num_shards : 1 ,
@@ -202,6 +209,7 @@ impl<'a> Default for ShardContext<'a> {
202209 ) ,
203210 cur_shard_cycle_range : Tracer :: SUBCYCLES_PER_INSN as usize ..usize:: MAX ,
204211 expected_inst_per_shard : usize:: MAX ,
212+ max_num_cross_shard_accesses,
205213 }
206214 }
207215}
@@ -231,6 +239,10 @@ impl<'a> ShardContext<'a> {
231239 let subcycle_per_insn = Tracer :: SUBCYCLES_PER_INSN as usize ;
232240 let max_threads = max_usable_threads ( ) ;
233241
242+ let max_num_cross_shard_accesses = std:: env:: var ( "CENO_CROSS_SHARD_LIMIT" )
243+ . map ( |v| v. parse ( ) . unwrap_or ( DEFAULT_CROSS_SHARD_ACCESS_LIMIT ) )
244+ . unwrap_or ( DEFAULT_CROSS_SHARD_ACCESS_LIMIT ) ;
245+
234246 // strategies
235247 // 0. set cur_num_shards = num_provers
236248 // 1. split instructions evenly by cur_num_shards
@@ -323,6 +335,7 @@ impl<'a> ShardContext<'a> {
323335 ) ,
324336 cur_shard_cycle_range,
325337 expected_inst_per_shard,
338+ max_num_cross_shard_accesses,
326339 }
327340 } )
328341 . collect_vec ( )
@@ -355,6 +368,7 @@ impl<'a> ShardContext<'a> {
355368 write_records_tbs : Either :: Right ( write) ,
356369 cur_shard_cycle_range : self . cur_shard_cycle_range . clone ( ) ,
357370 expected_inst_per_shard : self . expected_inst_per_shard ,
371+ max_num_cross_shard_accesses : self . max_num_cross_shard_accesses ,
358372 } ,
359373 )
360374 . collect_vec ( ) ,
@@ -1125,17 +1139,26 @@ pub fn generate_witness<'a, E: ExtensionField>(
11251139 pi. end_pc = current_shard_end_pc;
11261140 pi. end_cycle = current_shard_end_cycle;
11271141 // set shard ram bus expected output to pi
1128- let shard_ram_witness = zkvm_witness. get_table_witness ( & ShardRamCircuit :: < E > :: name ( ) ) ;
1129- if let Some ( shard_ram_witness) = shard_ram_witness
1130- && shard_ram_witness[ 0 ] . num_instances ( ) > 0
1131- {
1132- for ( f, v) in ShardRamCircuit :: < E > :: extract_ec_sum (
1133- & system_config. mmu_config . ram_bus_circuit ,
1134- & shard_ram_witness[ 0 ] ,
1135- )
1136- . into_iter ( )
1137- . zip_eq ( pi. shard_rw_sum . as_mut_slice ( ) )
1138- {
1142+ let shard_ram_witnesses = zkvm_witness. get_witness ( & ShardRamCircuit :: < E > :: name ( ) ) ;
1143+
1144+ if let Some ( shard_ram_witnesses) = shard_ram_witnesses {
1145+ let shard_ram_ec_sum: SepticPoint < E :: BaseField > = shard_ram_witnesses
1146+ . iter ( )
1147+ . filter ( |shard_ram_witness| shard_ram_witness. num_instances [ 0 ] > 0 )
1148+ . map ( |shard_ram_witness| {
1149+ ShardRamCircuit :: < E > :: extract_ec_sum (
1150+ & system_config. mmu_config . ram_bus_circuit ,
1151+ & shard_ram_witness. witness_rmms [ 0 ] ,
1152+ )
1153+ } )
1154+ . sum ( ) ;
1155+
1156+ let xy = shard_ram_ec_sum
1157+ . x
1158+ . 0
1159+ . iter ( )
1160+ . chain ( shard_ram_ec_sum. y . 0 . iter ( ) ) ;
1161+ for ( f, v) in xy. zip_eq ( pi. shard_rw_sum . as_mut_slice ( ) ) {
11391162 * v = f. to_canonical_u64 ( ) as u32 ;
11401163 }
11411164 }
0 commit comments