Commit 006b1e6
fix: hot fix the if bug
1 parent 367d667 commit 006b1e6

File tree

2 files changed: +140 −184 lines changed


ditec_wdn_dataset/core/datasets_large.py

Lines changed: 71 additions & 115 deletions
@@ -186,7 +186,6 @@ def __init__(
         verbose: bool = True,
         split_type: Literal["temporal", "scene"] = "scene",
         split_set: Literal["train", "val", "test", "all"] = "all",
-        split_fixed: bool = True,
         split_per_network: bool = True,
         skip_nodes_list: list[list[str]] = [],
         skip_types_list: list[list[str]] = [],
@@ -216,7 +215,6 @@ def __init__(
             verbose (bool, optional): flag indicating whether to log debug info. Defaults to True.
             split_type (Literal["temporal", "scene"], optional): Deprecated. Defaults to "scene".
             split_set (Literal["train", "val", "test", "all"], optional): selects which subset to use. Defaults to "all".
-            split_fixed (bool, optional): flag indicating whether the cutting positions are pre-defined (fixed) or dynamic
             split_per_network (bool, optional): If True, for each network, we split train, valid, and test individually (useful for multi-network joint training). Otherwise, we concatenate all networks into a single to-be-split array
             skip_nodes_list (list[list[str]], optional): add extra skipped node names; otherwise, load skip_nodes from the zarr zip files. Defaults to [].
             skip_types_list (list[list[str]], optional): skip nodes in bulk by node type (junction, reservoir, tank). Defaults to [].
@@ -251,7 +249,6 @@ def __init__(
         self.skip_types_list = skip_types_list
         self.split_type: Literal["temporal", "scene"] = split_type
         self.split_ratios: tuple[float, float, float] = (0.6, 0.2, 0.2)
-        self.split_fixed = split_fixed
         self.split_per_network = split_per_network
         self._arrays: list[tuple[zarr.Array, zarr.Array, zarr.Array | None, zarr.Array | None, zarr.Array | None]] = []
         self._index_map: dict[int, tuple[int | None, int | None]] = {}
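
The hard-coded (0.6, 0.2, 0.2) ratios kept above drive all of the splitting below. A minimal, runnable sketch of that arithmetic, assuming a hypothetical total of 1000 samples (the int() truncation and the remainder-to-test step mirror compute_subset_ids_by_ratio):

    split_ratios = (0.6, 0.2, 0.2)
    length = 1000  # assumed total sample count, for illustration only
    expected_train = int(length * split_ratios[0])            # 600
    expected_valid = int(length * split_ratios[1])            # 200
    expected_test = length - expected_train - expected_valid  # 200; absorbs rounding
    print(expected_train, expected_valid, expected_test)      # 600 200 200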
@@ -281,10 +278,7 @@ def __init__(
         self.update_indices()

     def _get_num_samples(self) -> int:
-        if self.split_fixed:
-            return self.num_records if self.num_records is not None else self.length
-        else:
-            return self.length
+        return self.length

     def custom_process(self) -> None:
         # load arrays from zip file (and input_paths)
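
The hot fix collapses the `if self.split_fixed:` branch so the sample count always comes from the index map. A small sketch of the behavioral difference, using assumed standalone values in place of the instance attributes (presumably this mismatch is the "if bug" named in the commit message):

    num_records = 500  # assumed user-requested cap
    length = 480       # assumed actual size of the index map

    # old behavior (split_fixed=True): reports num_records even when the
    # index map holds fewer entries
    old_num_samples = num_records if num_records is not None else length  # 500

    # new behavior: always report the index map's length
    new_num_samples = length  # 480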
@@ -446,76 +440,60 @@ def compute_subset_ids_by_ratio(self, split_ratios: tuple[float, float, float],
         train_ids, val_ids, test_ids = [], [], []
         len_of_list = len(self._num_samples_per_network_list)
         # if not splitting per network, or there is only a single network, we split based on flattened ids
-        if not self.split_per_network or len_of_list == 1:
-            left = int(self.length * split_ratios[0])
-            right = int(left + self.length * split_ratios[1])
-
-            flatten_ids = np.asarray(list(self._index_map.keys()))
-
-            flatten_ids = flatten_ids.tolist()
-            train_ids = flatten_ids[:left]
-            val_ids = flatten_ids[left:right]
-            test_ids = flatten_ids[right:]
-            if self.split_fixed:
-                selected_trains = int(num_samples * split_ratios[0])
-                selected_vals = int(num_samples * split_ratios[1])
-                selected_test = num_samples - selected_trains - selected_vals
-
-                train_ids = train_ids[:selected_trains]
-                val_ids = val_ids[:selected_vals]
-                test_ids = test_ids[:selected_test]
-        else:
-            # to split per network, we compute train/val/test individually
-            # degree of freedom will be (len_of_list - 1)
-            expected_train_samples = int(self.length * split_ratios[0])
-            expected_valid_samples = int(self.length * split_ratios[1])
-            expected_test_samples = self.length - expected_train_samples - expected_valid_samples
-            flatten_ids = np.asarray(list(self._index_map.keys()))
-
-            num_samples_per_network = num_samples // len(self._num_samples_per_network_list)
-            current_nid = 0
-            for i, network_num_samples in enumerate(self._num_samples_per_network_list):
-                network_flatten_ids = flatten_ids[current_nid : current_nid + network_num_samples]
-
-                if self.batch_axis_choice == "snapshot":
-                    # with snapshots, we still split by scene to ensure scenario independence
-                    # f_0 -> (n_0, t_0), f_1 -> (n_0, t_1), ..., f_T -> (n_0, t_T), f_T+1 -> (n_1, t_0), ...
-                    time_dim = self._roots[i].time_dim
-                    num_scenes = len(network_flatten_ids) // time_dim
-                    left = int(num_scenes * split_ratios[0])
-                    right = int(left + num_scenes * split_ratios[1])
-
-                    left = left * time_dim
-                    right = right * time_dim
-
-                    network_train_ids = network_flatten_ids[:left]
-                    network_val_ids = network_flatten_ids[left:right]
-                    network_test_ids = network_flatten_ids[right:]
+        # if not self.split_per_network or len_of_list == 1:
+        #     left = int(self.length * split_ratios[0])
+        #     right = int(left + self.length * split_ratios[1])
+
+        #     flatten_ids = np.asarray(list(self._index_map.keys()))
+
+        #     flatten_ids = flatten_ids.tolist()
+        #     train_ids = flatten_ids[:left]
+        #     val_ids = flatten_ids[left:right]
+        #     test_ids = flatten_ids[right:]
+        # else:
+        # to split per network, we compute train/val/test individually
+        # degree of freedom will be (len_of_list - 1)
+        expected_train_samples = int(self.length * split_ratios[0])
+        expected_valid_samples = int(self.length * split_ratios[1])
+        expected_test_samples = self.length - expected_train_samples - expected_valid_samples
+        flatten_ids = np.asarray(list(self._index_map.keys()))
+
+        num_samples_per_network = num_samples // len(self._num_samples_per_network_list)
+        current_nid = 0
+        for i, network_num_samples in enumerate(self._num_samples_per_network_list):
+            network_flatten_ids = flatten_ids[current_nid : current_nid + network_num_samples]
+
+            if self.batch_axis_choice == "snapshot":
+                # with snapshots, we still split by scene to ensure scenario independence
+                # f_0 -> (n_0, t_0), f_1 -> (n_0, t_1), ..., f_T -> (n_0, t_T), f_T+1 -> (n_1, t_0), ...
+                time_dim = self._roots[i].time_dim
+                num_scenes = len(network_flatten_ids) // time_dim
+                left = int(num_scenes * split_ratios[0])
+                right = int(left + num_scenes * split_ratios[1])
+
+                left = left * time_dim
+                right = right * time_dim
+
+                network_train_ids = network_flatten_ids[:left]
+                network_val_ids = network_flatten_ids[left:right]
+                network_test_ids = network_flatten_ids[right:]

-                else:
-                    left = int(network_num_samples * split_ratios[0])
-                    right = int(left + network_num_samples * split_ratios[1])
-                    network_train_ids = network_flatten_ids[:left]
-                    network_val_ids = network_flatten_ids[left:right]
-                    network_test_ids = network_flatten_ids[right:]
-
-                if i == len_of_list - 1:
-                    network_train_ids = network_train_ids[: expected_train_samples - len(train_ids)]
-                    network_val_ids = network_val_ids[: expected_valid_samples - len(val_ids)]
-                    network_test_ids = network_test_ids[: expected_test_samples - len(test_ids)]
-
-                if self.split_fixed:
-                    selected_trains = int(num_samples_per_network * split_ratios[0])
-                    selected_vals = int(num_samples_per_network * split_ratios[1])
-                    selected_test = num_samples_per_network - selected_trains - selected_vals
-                    network_train_ids = network_train_ids[:selected_trains]
-                    network_val_ids = network_val_ids[:selected_vals]
-                    network_test_ids = network_test_ids[:selected_test]
-
-                train_ids.extend(network_train_ids.tolist())
-                val_ids.extend(network_val_ids.tolist())
-                test_ids.extend(network_test_ids.tolist())
-                current_nid += network_num_samples
+            else:
+                left = int(network_num_samples * split_ratios[0])
+                right = int(left + network_num_samples * split_ratios[1])
+                network_train_ids = network_flatten_ids[:left]
+                network_val_ids = network_flatten_ids[left:right]
+                network_test_ids = network_flatten_ids[right:]
+
+            if i == len_of_list - 1:
+                network_train_ids = network_train_ids[: expected_train_samples - len(train_ids)]
+                network_val_ids = network_val_ids[: expected_valid_samples - len(val_ids)]
+                network_test_ids = network_test_ids[: expected_test_samples - len(test_ids)]
+
+            train_ids.extend(network_train_ids.tolist())
+            val_ids.extend(network_val_ids.tolist())
+            test_ids.extend(network_test_ids.tolist())
+            current_nid += network_num_samples

         return train_ids, val_ids, test_ids

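A self-contained sketch of the per-network splitting path that survives this commit, with assumed sample counts and time dimension; for the snapshot axis, cut points are rounded down to scene boundaries so one scenario never straddles two subsets (the last-network truncation step is omitted for brevity):

    import numpy as np

    split_ratios = (0.6, 0.2, 0.2)
    time_dim = 4                    # assumed timesteps per scene
    samples_per_network = [40, 40]  # assumed snapshot counts per network

    train_ids, val_ids, test_ids = [], [], []
    current = 0
    for n in samples_per_network:
        ids = np.arange(current, current + n)
        # split at scene boundaries, mirroring the "snapshot" branch above
        num_scenes = n // time_dim
        left = int(num_scenes * split_ratios[0]) * time_dim
        right = left + int(num_scenes * split_ratios[1]) * time_dim
        train_ids.extend(ids[:left].tolist())
        val_ids.extend(ids[left:right].tolist())
        test_ids.extend(ids[right:].tolist())
        current += n

    print(len(train_ids), len(val_ids), len(test_ids))  # 48 16 16
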
@@ -529,54 +507,32 @@ def compute_indices(
         flatten_index = 0
         self.load_roots(zip_file_paths, input_paths)

-        root_sizes: list[int] = [r.compute_first_size() for r in self._roots]
-        if self.batch_axis_choice == "scene":
-            sum_of_root_sizes = sum(root_sizes)
-        elif self.batch_axis_choice == "temporal":
-            sum_of_root_sizes = sum([r.time_dim for r in self._roots])
-        else:  # snapshot
-            sum_of_root_sizes = sum([r.time_dim * root_sizes[i] for i, r in enumerate(self._roots)])
-
-        total_num_samples: int = sum_of_root_sizes  # if self.num_records is None else min(sum_of_root_sizes, self.num_records)
-        num_samples_per_network = total_num_samples // len(self._roots)
+        # root_sizes: list[int] = [r.compute_first_size() for r in self._roots]
+        # if self.batch_axis_choice == "scene":
+        #     sum_of_root_sizes = sum(root_sizes)
+        # elif self.batch_axis_choice == "temporal":
+        #     sum_of_root_sizes = sum([r.time_dim for r in self._roots])
+        # else:  # snapshot
+        #     sum_of_root_sizes = sum([r.time_dim * root_sizes[i] for i, r in enumerate(self._roots)])

         for network_index, root in enumerate(self._roots):
             if self.batch_axis_choice == "scene":
                 # arr WILL have shape <merged>(#scenes, #nodes_or_#links, #statics + time_dims * #dynamics)
-                if self.split_fixed:
-                    relative_scene_ids = np.arange(root_sizes[network_index])
-                    relative_scene_ids = relative_scene_ids[:num_samples_per_network]
-                    num_samples = len(relative_scene_ids)  # min(num_samples_per_network, root_sizes[network_index])
-                else:
-                    num_samples = root.compute_first_size()  # if self.num_records is None else min(self.num_records, root.compute_first_size())
-                    relative_scene_ids = np.arange(num_samples)
+                num_samples = root.compute_first_size()  # if self.num_records is None else min(self.num_records, root.compute_first_size())
+                relative_scene_ids = np.arange(num_samples)
                 tuples = (relative_scene_ids, None)
             elif self.batch_axis_choice == "temporal":
-                if self.split_fixed:
-                    relative_time_ids = np.arange(root.time_dim)
-                    relative_time_ids = relative_time_ids[:num_samples_per_network]
-                    num_samples = len(relative_time_ids)
-                else:
-                    num_samples = root.time_dim
-                    relative_time_ids = np.arange(num_samples)
+                num_samples = root.time_dim
+                relative_time_ids = np.arange(num_samples)
                 tuples = (None, relative_time_ids)

             elif self.batch_axis_choice == "snapshot":
-                if self.split_fixed:
-                    time_dim = root.time_dim
-                    num_scenes = root_sizes[network_index]
-                    relative_scene_ids = np.arange(num_scenes).repeat(time_dim)  # .reshape([-1, 1])
-                    relative_time_ids = np.tile(np.arange(time_dim), reps=num_scenes)  # .reshape([-1, 1])
-                    relative_scene_ids = relative_scene_ids[:num_samples_per_network]
-                    relative_time_ids = relative_time_ids[:num_samples_per_network]
-                    num_samples = len(relative_scene_ids)
-                else:
-                    num_scenes = root.compute_first_size()  # if self.num_records is None else min(self.num_records, root.compute_first_size())
-                    time_dim = root.time_dim
-                    relative_scene_ids = np.arange(num_scenes).repeat(time_dim)  # .reshape([-1, 1])
-                    relative_time_ids = np.tile(np.arange(time_dim), reps=num_scenes)  # .reshape([-1, 1])
-                    tuples = (relative_scene_ids, relative_time_ids)
-                    num_samples = len(relative_scene_ids)
+                num_scenes = root.compute_first_size()  # if self.num_records is None else min(self.num_records, root.compute_first_size())
+                time_dim = root.time_dim
+                relative_scene_ids = np.arange(num_scenes).repeat(time_dim)  # .reshape([-1, 1])
+                relative_time_ids = np.tile(np.arange(time_dim), reps=num_scenes)  # .reshape([-1, 1])
+                tuples = (relative_scene_ids, relative_time_ids)
+                num_samples = len(relative_scene_ids)
                 tuples = (relative_scene_ids, relative_time_ids)
             else:
                 raise NotImplementedError
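
The surviving "snapshot" branch enumerates every (scene, time) pair with repeat/tile. A runnable sketch with assumed sizes (3 scenes, 2 timesteps):

    import numpy as np

    num_scenes, time_dim = 3, 2  # assumed sizes
    scene_ids = np.arange(num_scenes).repeat(time_dim)        # [0 0 1 1 2 2]
    time_ids = np.tile(np.arange(time_dim), reps=num_scenes)  # [0 1 0 1 0 1]
    print(list(zip(scene_ids.tolist(), time_ids.tolist())))
    # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]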
