Skip to content

Commit 2e58e89

Browse files
committed
chore: minimal example
1 parent a8add93 commit 2e58e89

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

tut_interface.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from ditec_wdn_dataset.hf.dataset import GidaV7
1313
from ditec_wdn_dataset.utils.configs import GidaConfig
1414

15-
from torch_geometric.data import Batch
16-
1715

1816
def tutorial_v6(gida_yaml_path: str) -> list[GidaV6]:
1917
gida_config = GidaConfig()
@@ -97,5 +95,27 @@ def tutorial_v7(gida_yaml_path: str) -> list[GidaV7]:
9795

9896

9997
if __name__ == "__main__":
100-
tutorial_v6("ditec_wdn_dataset/arguments/test_data_interface_v6_config.yaml")
101-
tutorial_v7("ditec_wdn_dataset/arguments/test_data_interface_v7_config.yaml")
98+
# they are minimal examples loaded from configs
99+
# tutorial_v6("ditec_wdn_dataset/arguments/test_data_interface_v6_config.yaml")
100+
# tutorial_v7("ditec_wdn_dataset/arguments/test_data_interface_v7_config.yaml")
101+
102+
# this is a minimal example for the data interface
103+
full_gida = GidaV7(
104+
wdn_names=["CTOWN_1GB_24H"],
105+
node_attrs=["pressure", GidaV7.Node_Elevation],
106+
edge_attrs=["pipe_diameter"],
107+
edge_label_attrs=[], # keep it empty if unsed,
108+
label_attrs=["demand"],
109+
num_records=100, # keep it small to prevent OOM
110+
batch_axis_choice="scene", # record unit (scenario)
111+
verbose=True, # for more details
112+
)
113+
print(len(full_gida.train_ids))
114+
train_set = full_gida.get_set(full_gida.train_ids)
115+
print(next(iter(train_set)))
116+
print(len(full_gida.val_ids))
117+
valid_set = full_gida.get_set(full_gida.val_ids)
118+
print(next(iter(valid_set)))
119+
print(len(full_gida.test_ids))
120+
test_set = full_gida.get_set(full_gida.test_ids)
121+
print(next(iter(test_set)))

0 commit comments

Comments
 (0)