Skip to content

Commit b6740b7

Browse files
committed
feat: interleaving indices
1 parent cfdf12e commit b6740b7

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

ditec_wdn_dataset/core/datasets_large.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,7 @@ def compute_indices(
624624
# total_num_samples: int = sum_of_root_sizes # if self.num_records is None else min(sum_of_root_sizes, self.num_records)
625625
# num_samples_per_network = total_num_samples // len(self._roots)
626626

627+
num_networks = len(self._roots)
627628
for network_index, root in enumerate(self._roots):
628629
if self.batch_axis_choice == "scene":
629630
# arr WILL have shape <merged>(#scenes, #nodes_or_#links, #statics + time_dims * #dynamics)
@@ -645,7 +646,8 @@ def compute_indices(
645646
else:
646647
raise NotImplementedError
647648
extended_network_ids = np.full([num_samples], network_index)
648-
flatten_ids = np.arange(flatten_index, flatten_index + num_samples)
649+
# flatten_ids = np.arange(flatten_index, flatten_index + num_samples)
650+
flatten_ids = np.arange(num_samples) * num_networks + network_index
649651

650652
local_chunk_map: dict[int, int] = {}
651653
lefts: np.ndarray | None = tuples[0] if tuples[0] is not None else None
@@ -699,6 +701,14 @@ def compute_indices(
699701
flatten_index += num_samples
700702
num_samples_per_network_list.append(num_samples)
701703

704+
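# Interleaving trick: flatten_ids are strided by num_networks, so sorting the index map by key alternates samples across networks. NOTE(review): keys can have gaps or exceed `length` when networks differ in num_samples -- confirm. The result will be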
705+
# 0 -> sample_0_dataset_0
706+
# 1 -> sample_0_dataset_1
707+
# N-1 -> sample_0_dataset_N-1
708+
# N -> sample_1_dataset_0
709+
# ...
710+
index_map = OrderedDict(sorted(index_map.items()))
711+
702712
length = flatten_index
703713
return length, index_map, network_map, chunk_map, num_samples_per_network_list
704714

ditec_wdn_dataset/hf/dataset.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def compute_indices(self, wdn_names: list[str]) -> tuple[int, dict[int, tuple[in
562562
f"ERROR! Simulation time (duration) must be equal in option batch_axis_choice <{self.batch_axis_choice}>, but get {time_dims}!"
563563
)
564564

565+
num_networks = len(self._roots)
565566
for network_index, root in enumerate(self._roots):
566567
if self.batch_axis_choice == "scene":
567568
# arr WILL have shape <merged>(#scenes, #nodes_or_#links, #statics + time_dims * #dynamics)
@@ -584,7 +585,8 @@ def compute_indices(self, wdn_names: list[str]) -> tuple[int, dict[int, tuple[in
584585
else:
585586
raise NotImplementedError
586587
extended_network_ids = np.full([num_samples], network_index)
587-
flatten_ids = np.arange(flatten_index, flatten_index + num_samples)
588+
# flatten_ids = np.arange(flatten_index, flatten_index + num_samples)
589+
flatten_ids = np.arange(num_samples) * num_networks + network_index
588590

589591
network_index_map: dict[int, tuple[int | None, int | None]] = {}
590592
# fid_nid_map: dict[int, int] = {}
@@ -601,7 +603,13 @@ def compute_indices(self, wdn_names: list[str]) -> tuple[int, dict[int, tuple[in
601603
# update flatten index indicator and network index
602604
flatten_index += num_samples
603605
num_samples_per_network_list.append(num_samples)
604-
606+
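# Interleaving trick: flatten_ids are strided by num_networks, so sorting the index map by key alternates samples across networks. NOTE(review): keys can have gaps or exceed `length` when networks differ in num_samples -- confirm. The result will be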
607+
# 0 -> sample_0_dataset_0
608+
# 1 -> sample_0_dataset_1
609+
# N-1 -> sample_0_dataset_N-1
610+
# N -> sample_1_dataset_0
611+
# ...
612+
index_map = OrderedDict(sorted(index_map.items()))
605613
length = flatten_index
606614
return length, index_map, network_map, num_samples_per_network_list
607615

0 commit comments

Comments
 (0)