Skip to content

Commit c7e6bf7

Browse files
committed
updates to use 32 bits of data + more logging
1 parent 5c8a128 commit c7e6bf7

File tree

3 files changed: +18 / -5 lines changed

3 files changed: +18 / -5 lines changed

baskerville/dawn-comparison/download.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import cdsapi
1010

1111
# Data will be downloaded here.
12-
download_path = Path("../era5-experiments/downloads")
12+
download_path = Path("../../downloads")
1313

1414
c = cdsapi.Client()
1515

@@ -51,7 +51,7 @@
5151
],
5252
"year": "2023",
5353
"month": "01",
54-
"day": ["01", "02", "03", "04", "05", "06", "07", "08"],
54+
"day": ["01", "02", "03", "04", "05", "06", "07", "08", "09"],
5555
"time": ["00:00", "06:00", "12:00", "18:00"],
5656
"format": "netcdf",
5757
},
@@ -89,7 +89,7 @@
8989
],
9090
"year": "2023",
9191
"month": "01",
92-
"day": ["01", "02", "03", "04", "05", "06", "07", "08"],
92+
"day": ["01", "02", "03", "04", "05", "06", "07", "08", "09"],
9393
"time": ["00:00", "06:00", "12:00", "18:00"],
9494
"format": "netcdf",
9595
},

dawn/scripts/era_v_download.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"06",
6565
"07",
6666
"08",
67+
"09",
6768
],
6869
"time": ["00:00", "06:00", "12:00", "18:00"],
6970
"format": "netcdf",
@@ -111,6 +112,7 @@
111112
"06",
112113
"07",
113114
"08",
115+
"09",
114116
],
115117
"time": ["00:00", "06:00", "12:00", "18:00"],
116118
"format": "netcdf",

train/scripts/train.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import time
88
import warnings
9+
from datetime import datetime as dt
910
from pathlib import Path
1011

1112
warnings.filterwarnings(
@@ -96,6 +97,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
9697
device_type = "cuda"
9798

9899
time_start_total = time.time()
100+
print(f"Script start time: {dt.now()}")
99101

100102
print("Initialising process group with backend", comms_backend, flush=True)
101103
# ToDo Run 2 or more processes.
@@ -108,6 +110,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
108110
device = f"{device_type}:{LOCAL_RANK}"
109111
print(f"Using {device=}")
110112

113+
print(f"Start time loading model: {dt.now()}")
111114
print("loading model...")
112115
model = Aurora(
113116
use_lora=False, # Model was not fine-tuned.
@@ -116,6 +119,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
116119
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
117120
if not xpu:
118121
torch.cuda.set_device(LOCAL_RANK)
122+
print(f"End time loading model: {dt.now()}")
119123

120124
download_path = Path(download_path)
121125

@@ -139,14 +143,20 @@ def main(download_path: str, shard: bool, xpu: bool = False):
139143
# AdamW, as used in the paper.
140144
optimizer = torch.optim.AdamW(model.parameters())
141145

146+
time_start_loading_data = time.time()
147+
print(f"Start time loading data: {dt.now()}")
142148
print("loading data...")
143149
dataset = AuroraDataset(
144150
data_path=download_path,
145151
t=1,
146152
static_data=Path("static.nc"),
147-
surface_data=Path("2023-01-surface-level.nc"),
148-
atmos_data=Path("2023-01-atmospheric.nc"),
153+
surface_data=Path("2023-01-surface-level-34.nc"),
154+
atmos_data=Path("2023-01-atmospheric-34.nc"),
149155
)
156+
time_end_loading_data = time.time()
157+
print(f"End time loading data: {dt.now()}")
158+
print(f"Time loading data: {time_end_loading_data - time_start_loading_data}")
159+
150160
sampler = DistributedSampler(dataset)
151161
data_loader = DataLoader(
152162
dataset=dataset,
@@ -188,6 +198,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
188198
optimizer.step()
189199

190200
time_end = time.time()
201+
print(f"Time for 1 iteration: {time_end - time_start}")
191202
times.append(time_end - time_start)
192203
time_start = time.time()
193204

0 commit comments

Comments
 (0)