@@ -186,7 +186,6 @@ def __init__(
186186 verbose : bool = True ,
187187 split_type : Literal ["temporal" , "scene" ] = "scene" ,
188188 split_set : Literal ["train" , "val" , "test" , "all" ] = "all" ,
189- split_fixed : bool = True ,
190189 split_per_network : bool = True ,
191190 skip_nodes_list : list [list [str ]] = [],
192191 skip_types_list : list [list [str ]] = [],
@@ -216,7 +215,6 @@ def __init__(
216215 verbose (bool, optional): flag indicates whether logging debug info. Defaults to True.
217216 split_type (Literal["temporal", "scene"], optional): Deprecated. Defaults to "scene".
218217 split_set (Literal["train", "val", "test", "all"], optional): to split subsets. Defaults to "all".
219- split_fixed (bool, optional): flag indicates whether the cutting positions are pre-defined (fixed) or dynamic
220218 split_per_network (bool, optional): If True, foreach network, we split train, valid, test individually (Useful for multiple network joint-training). Otherwise, we concatenate all networks into a single to-be-splitted array
221219 skip_nodes_list (list[list[str]], optional): add extra skipped node names, otherwise, load skip_nodes in zarr zip files. Defaults to [].
222220 skip_types_list (list[list[str]], optional): massively skip by node types (juctions, reservoir, tank). Defaults to [].
@@ -251,7 +249,6 @@ def __init__(
251249 self .skip_types_list = skip_types_list
252250 self .split_type : Literal ["temporal" , "scene" ] = split_type
253251 self .split_ratios : tuple [float , float , float ] = (0.6 , 0.2 , 0.2 )
254- self .split_fixed = split_fixed
255252 self .split_per_network = split_per_network
256253 self ._arrays : list [tuple [zarr .Array , zarr .Array , zarr .Array | None , zarr .Array | None , zarr .Array | None ]] = []
257254 self ._index_map : dict [int , tuple [int | None , int | None ]] = {}
@@ -281,10 +278,7 @@ def __init__(
281278 self .update_indices ()
282279
283280 def _get_num_samples (self ) -> int :
284- if self .split_fixed :
285- return self .num_records if self .num_records is not None else self .length
286- else :
287- return self .length
281+ return self .length
288282
289283 def custom_process (self ) -> None :
290284 # load arrays from zip file (and input_paths)
@@ -446,76 +440,60 @@ def compute_subset_ids_by_ratio(self, split_ratios: tuple[float, float, float],
446440 train_ids , val_ids , test_ids = [], [], []
447441 len_of_list = len (self ._num_samples_per_network_list )
448442 # if not split per network or existing a single network only, we split based on flatten ids
449- if not self .split_per_network or len_of_list == 1 :
450- left = int (self .length * split_ratios [0 ])
451- right = int (left + self .length * split_ratios [1 ])
452-
453- flatten_ids = np .asarray (list (self ._index_map .keys ()))
454-
455- flatten_ids = flatten_ids .tolist ()
456- train_ids = flatten_ids [:left ]
457- val_ids = flatten_ids [left :right ]
458- test_ids = flatten_ids [right :]
459- if self .split_fixed :
460- selected_trains = int (num_samples * split_ratios [0 ])
461- selected_vals = int (num_samples * split_ratios [1 ])
462- selected_test = num_samples - selected_trains - selected_vals
463-
464- train_ids = train_ids [:selected_trains ]
465- val_ids = val_ids [:selected_vals ]
466- test_ids = test_ids [:selected_test ]
467- else :
468- # to split per network, we compute train/val/test individually
469- # degree of freedom will be (len_of_list - 1)
470- expected_train_samples = int (self .length * split_ratios [0 ])
471- expected_valid_samples = int (self .length * split_ratios [1 ])
472- expected_test_samples = self .length - expected_train_samples - expected_valid_samples
473- flatten_ids = np .asarray (list (self ._index_map .keys ()))
474-
475- num_samples_per_network = num_samples // len (self ._num_samples_per_network_list )
476- current_nid = 0
477- for i , network_num_samples in enumerate (self ._num_samples_per_network_list ):
478- network_flatten_ids = flatten_ids [current_nid : current_nid + network_num_samples ]
479-
480- if self .batch_axis_choice == "snapshot" :
481- # with snapshots, we still split by scence to ensure the scenario independence
482- # f_0-> (n_0, t_0), f_1 -> (n_0, t_1), ..., f_T -> (n_0, t_T), f_T+1 -> (n_1, t_0), ...
483- time_dim = self ._roots [i ].time_dim
484- num_scenes = len (network_flatten_ids ) // time_dim
485- left = int (num_scenes * split_ratios [0 ])
486- right = int (left + num_scenes * split_ratios [1 ])
487-
488- left = left * time_dim
489- right = right * time_dim
490-
491- network_train_ids = network_flatten_ids [:left ]
492- network_val_ids = network_flatten_ids [left :right ]
493- network_test_ids = network_flatten_ids [right :]
443+ # if not self.split_per_network or len_of_list == 1:
444+ # left = int(self.length * split_ratios[0])
445+ # right = int(left + self.length * split_ratios[1])
446+
447+ # flatten_ids = np.asarray(list(self._index_map.keys()))
448+
449+ # flatten_ids = flatten_ids.tolist()
450+ # train_ids = flatten_ids[:left]
451+ # val_ids = flatten_ids[left:right]
452+ # test_ids = flatten_ids[right:]
453+ # else:
454+ # to split per network, we compute train/val/test individually
455+ # degree of freedom will be (len_of_list - 1)
456+ expected_train_samples = int (self .length * split_ratios [0 ])
457+ expected_valid_samples = int (self .length * split_ratios [1 ])
458+ expected_test_samples = self .length - expected_train_samples - expected_valid_samples
459+ flatten_ids = np .asarray (list (self ._index_map .keys ()))
460+
461+ num_samples_per_network = num_samples // len (self ._num_samples_per_network_list )
462+ current_nid = 0
463+ for i , network_num_samples in enumerate (self ._num_samples_per_network_list ):
464+ network_flatten_ids = flatten_ids [current_nid : current_nid + network_num_samples ]
465+
466+ if self .batch_axis_choice == "snapshot" :
468+ # with snapshots, we still split by scene to ensure the scenario independence
468+ # f_0-> (n_0, t_0), f_1 -> (n_0, t_1), ..., f_T -> (n_0, t_T), f_T+1 -> (n_1, t_0), ...
469+ time_dim = self ._roots [i ].time_dim
470+ num_scenes = len (network_flatten_ids ) // time_dim
471+ left = int (num_scenes * split_ratios [0 ])
472+ right = int (left + num_scenes * split_ratios [1 ])
473+
474+ left = left * time_dim
475+ right = right * time_dim
476+
477+ network_train_ids = network_flatten_ids [:left ]
478+ network_val_ids = network_flatten_ids [left :right ]
479+ network_test_ids = network_flatten_ids [right :]
494480
495- else :
496- left = int (network_num_samples * split_ratios [0 ])
497- right = int (left + network_num_samples * split_ratios [1 ])
498- network_train_ids = network_flatten_ids [:left ]
499- network_val_ids = network_flatten_ids [left :right ]
500- network_test_ids = network_flatten_ids [right :]
501-
502- if i == len_of_list - 1 :
503- network_train_ids = network_train_ids [: expected_train_samples - len (train_ids )]
504- network_val_ids = network_val_ids [: expected_valid_samples - len (val_ids )]
505- network_test_ids = network_test_ids [: expected_test_samples - len (test_ids )]
506-
507- if self .split_fixed :
508- selected_trains = int (num_samples_per_network * split_ratios [0 ])
509- selected_vals = int (num_samples_per_network * split_ratios [1 ])
510- selected_test = num_samples_per_network - selected_trains - selected_vals
511- network_train_ids = network_train_ids [:selected_trains ]
512- network_val_ids = network_val_ids [:selected_vals ]
513- network_test_ids = network_test_ids [:selected_test ]
514-
515- train_ids .extend (network_train_ids .tolist ())
516- val_ids .extend (network_val_ids .tolist ())
517- test_ids .extend (network_test_ids .tolist ())
518- current_nid += network_num_samples
481+ else :
482+ left = int (network_num_samples * split_ratios [0 ])
483+ right = int (left + network_num_samples * split_ratios [1 ])
484+ network_train_ids = network_flatten_ids [:left ]
485+ network_val_ids = network_flatten_ids [left :right ]
486+ network_test_ids = network_flatten_ids [right :]
487+
488+ if i == len_of_list - 1 :
489+ network_train_ids = network_train_ids [: expected_train_samples - len (train_ids )]
490+ network_val_ids = network_val_ids [: expected_valid_samples - len (val_ids )]
491+ network_test_ids = network_test_ids [: expected_test_samples - len (test_ids )]
492+
493+ train_ids .extend (network_train_ids .tolist ())
494+ val_ids .extend (network_val_ids .tolist ())
495+ test_ids .extend (network_test_ids .tolist ())
496+ current_nid += network_num_samples
519497
520498 return train_ids , val_ids , test_ids
521499
@@ -529,54 +507,32 @@ def compute_indices(
529507 flatten_index = 0
530508 self .load_roots (zip_file_paths , input_paths )
531509
532- root_sizes : list [int ] = [r .compute_first_size () for r in self ._roots ]
533- if self .batch_axis_choice == "scene" :
534- sum_of_root_sizes = sum (root_sizes )
535- elif self .batch_axis_choice == "temporal" :
536- sum_of_root_sizes = sum ([r .time_dim for r in self ._roots ])
537- else : # snapshot
538- sum_of_root_sizes = sum ([r .time_dim * root_sizes [i ] for i , r in enumerate (self ._roots )])
539-
540- total_num_samples : int = sum_of_root_sizes # if self.num_records is None else min(sum_of_root_sizes, self.num_records)
541- num_samples_per_network = total_num_samples // len (self ._roots )
510+ # root_sizes: list[int] = [r.compute_first_size() for r in self._roots]
511+ # if self.batch_axis_choice == "scene":
512+ # sum_of_root_sizes = sum(root_sizes)
513+ # elif self.batch_axis_choice == "temporal":
514+ # sum_of_root_sizes = sum([r.time_dim for r in self._roots])
515+ # else: # snapshot
516+ # sum_of_root_sizes = sum([r.time_dim * root_sizes[i] for i, r in enumerate(self._roots)])
542517
543518 for network_index , root in enumerate (self ._roots ):
544519 if self .batch_axis_choice == "scene" :
545520 # arr WILL have shape <merged>(#scenes, #nodes_or_#links, #statics + time_dims * #dynamics)
546- if self .split_fixed :
547- relative_scene_ids = np .arange (root_sizes [network_index ])
548- relative_scene_ids = relative_scene_ids [:num_samples_per_network ]
549- num_samples = len (relative_scene_ids ) # min(num_samples_per_network, root_sizes[network_index])
550- else :
551- num_samples = root .compute_first_size () #if self.num_records is None else min(self.num_records, root.compute_first_size())
552- relative_scene_ids = np .arange (num_samples )
521+ num_samples = root .compute_first_size () # if self.num_records is None else min(self.num_records, root.compute_first_size())
522+ relative_scene_ids = np .arange (num_samples )
553523 tuples = (relative_scene_ids , None )
554524 elif self .batch_axis_choice == "temporal" :
555- if self .split_fixed :
556- relative_time_ids = np .arange (root .time_dim )
557- relative_time_ids = relative_time_ids [:num_samples_per_network ]
558- num_samples = len (relative_time_ids )
559- else :
560- num_samples = root .time_dim
561- relative_time_ids = np .arange (num_samples )
525+ num_samples = root .time_dim
526+ relative_time_ids = np .arange (num_samples )
562527 tuples = (None , relative_time_ids )
563528
564529 elif self .batch_axis_choice == "snapshot" :
565- if self .split_fixed :
566- time_dim = root .time_dim
567- num_scenes = root_sizes [network_index ]
568- relative_scene_ids = np .arange (num_scenes ).repeat (time_dim ) # .reshape([-1, 1])
569- relative_time_ids = np .tile (np .arange (time_dim ), reps = num_scenes ) # .reshape([-1, 1])
570- relative_scene_ids = relative_scene_ids [:num_samples_per_network ]
571- relative_time_ids = relative_time_ids [:num_samples_per_network ]
572- num_samples = len (relative_scene_ids )
573- else :
574- num_scenes = root .compute_first_size () #if self.num_records is None else min(self.num_records, root.compute_first_size())
575- time_dim = root .time_dim
576- relative_scene_ids = np .arange (num_scenes ).repeat (time_dim ) # .reshape([-1, 1])
577- relative_time_ids = np .tile (np .arange (time_dim ), reps = num_scenes ) # .reshape([-1, 1])
578- tuples = (relative_scene_ids , relative_time_ids )
579- num_samples = len (relative_scene_ids )
530+ num_scenes = root .compute_first_size () # if self.num_records is None else min(self.num_records, root.compute_first_size())
531+ time_dim = root .time_dim
532+ relative_scene_ids = np .arange (num_scenes ).repeat (time_dim ) # .reshape([-1, 1])
533+ relative_time_ids = np .tile (np .arange (time_dim ), reps = num_scenes ) # .reshape([-1, 1])
534+ tuples = (relative_scene_ids , relative_time_ids )
535+ num_samples = len (relative_scene_ids )
580536 tuples = (relative_scene_ids , relative_time_ids )
581537 else :
582538 raise NotImplementedError
0 commit comments