@@ -562,6 +562,7 @@ def compute_indices(self, wdn_names: list[str]) -> tuple[int, dict[int, tuple[in
562562 f"ERROR! Simulation time (duration) must be equal in option batch_axis_choice <{ self .batch_axis_choice } >, but get { time_dims } !"
563563 )
564564
565+ num_networks = len (self ._roots )
565566 for network_index , root in enumerate (self ._roots ):
566567 if self .batch_axis_choice == "scene" :
567568 # arr WILL have shape <merged>(#scenes, #nodes_or_#links, #statics + time_dims * #dynamics)
@@ -584,7 +585,8 @@ def compute_indices(self, wdn_names: list[str]) -> tuple[int, dict[int, tuple[in
584585 else :
585586 raise NotImplementedError
586587 extended_network_ids = np .full ([num_samples ], network_index )
587- flatten_ids = np .arange (flatten_index , flatten_index + num_samples )
588+ # flatten_ids = np.arange(flatten_index, flatten_index + num_samples)
589+ flatten_ids = np .arange (num_samples ) * num_networks + network_index
588590
589591 network_index_map : dict [int , tuple [int | None , int | None ]] = {}
590592 # fid_nid_map: dict[int, int] = {}
@@ -601,7 +603,13 @@ def compute_indices(self, wdn_names: list[str]) -> tuple[int, dict[int, tuple[in
601603 # update flatten index indicator and network index
602604 flatten_index += num_samples
603605 num_samples_per_network_list .append (num_samples )
604-
606+ # trick to perform interleaving, we sort the index map. The result will be
607+ # 0 -> sample_0_dataset_0
608+ # 1 -> sample_0_dataset_1
609+ # N-1 -> sample_0_dataset_N-1
610+ # N -> sample_1_dataset_0
611+ # ...
612+ index_map = OrderedDict (sorted (index_map .items ()))
605613 length = flatten_index
606614 return length , index_map , network_map , num_samples_per_network_list
607615
0 commit comments