|
12 | 12 | from ditec_wdn_dataset.hf.dataset import GidaV7 |
13 | 13 | from ditec_wdn_dataset.utils.configs import GidaConfig |
14 | 14 |
|
15 | | -from torch_geometric.data import Batch |
16 | | - |
17 | 15 |
|
18 | 16 | def tutorial_v6(gida_yaml_path: str) -> list[GidaV6]: |
19 | 17 | gida_config = GidaConfig() |
@@ -97,5 +95,27 @@ def tutorial_v7(gida_yaml_path: str) -> list[GidaV7]: |
97 | 95 |
|
98 | 96 |
|
99 | 97 | if __name__ == "__main__": |
100 | | - tutorial_v6("ditec_wdn_dataset/arguments/test_data_interface_v6_config.yaml") |
101 | | - tutorial_v7("ditec_wdn_dataset/arguments/test_data_interface_v7_config.yaml") |
| 98 | + # they are minimal examples loaded from configs |
| 99 | + # tutorial_v6("ditec_wdn_dataset/arguments/test_data_interface_v6_config.yaml") |
| 100 | + # tutorial_v7("ditec_wdn_dataset/arguments/test_data_interface_v7_config.yaml") |
| 101 | + |
| 102 | + # this is a minimal example for the data interface |
| 103 | + full_gida = GidaV7( |
| 104 | + wdn_names=["CTOWN_1GB_24H"], |
| 105 | + node_attrs=["pressure", GidaV7.Node_Elevation], |
| 106 | + edge_attrs=["pipe_diameter"], |
| 107 | + edge_label_attrs=[], # keep it empty if unsed, |
| 108 | + label_attrs=["demand"], |
| 109 | + num_records=100, # keep it small to prevent OOM |
| 110 | + batch_axis_choice="scene", # record unit (scenario) |
| 111 | + verbose=True, # for more details |
| 112 | + ) |
| 113 | + print(len(full_gida.train_ids)) |
| 114 | + train_set = full_gida.get_set(full_gida.train_ids) |
| 115 | + print(next(iter(train_set))) |
| 116 | + print(len(full_gida.val_ids)) |
| 117 | + valid_set = full_gida.get_set(full_gida.val_ids) |
| 118 | + print(next(iter(valid_set))) |
| 119 | + print(len(full_gida.test_ids)) |
| 120 | + test_set = full_gida.get_set(full_gida.test_ids) |
| 121 | + print(next(iter(test_set))) |
0 commit comments