66import re
77import time
88import warnings
9+ from datetime import datetime as dt
910from pathlib import Path
1011
1112warnings .filterwarnings (
@@ -96,6 +97,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
9697 device_type = "cuda"
9798
9899 time_start_total = time .time ()
100+ print (f"Script start time: { dt .now ()} " )
99101
100102 print ("Initialising process group with backend" , comms_backend , flush = True )
101103 # ToDo Run 2 or more processes.
@@ -108,6 +110,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
108110 device = f"{ device_type } :{ LOCAL_RANK } "
109111 print (f"Using { device = } " )
110112
113+ print (f"Start time loading model: { dt .now ()} " )
111114 print ("loading model..." )
112115 model = Aurora (
113116 use_lora = False , # Model was not fine-tuned.
@@ -116,6 +119,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
116119 model .load_checkpoint ("microsoft/aurora" , "aurora-0.25-pretrained.ckpt" )
117120 if not xpu :
118121 torch .cuda .set_device (LOCAL_RANK )
122+ print (f"End time loading model: { dt .now ()} " )
119123
120124 download_path = Path (download_path )
121125
@@ -139,14 +143,20 @@ def main(download_path: str, shard: bool, xpu: bool = False):
139143 # AdamW, as used in the paper.
140144 optimizer = torch .optim .AdamW (model .parameters ())
141145
146+ time_start_loading_data = time .time ()
147+ print (f"Start time loading data: { dt .now ()} " )
142148 print ("loading data..." )
143149 dataset = AuroraDataset (
144150 data_path = download_path ,
145151 t = 1 ,
146152 static_data = Path ("static.nc" ),
147- surface_data = Path ("2023-01-surface-level.nc" ),
148- atmos_data = Path ("2023-01-atmospheric.nc" ),
153+ surface_data = Path ("2023-01-surface-level-34 .nc" ),
154+ atmos_data = Path ("2023-01-atmospheric-34 .nc" ),
149155 )
156+ time_end_loading_data = time .time ()
157+ print (f"End time loading data: { dt .now ()} " )
158+ print (f"Time loading data: { time_end_loading_data - time_start_loading_data } " )
159+
150160 sampler = DistributedSampler (dataset )
151161 data_loader = DataLoader (
152162 dataset = dataset ,
@@ -188,6 +198,7 @@ def main(download_path: str, shard: bool, xpu: bool = False):
188198 optimizer .step ()
189199
190200 time_end = time .time ()
201+ print (f"Time for 1 iteration: { time_end - time_start } " )
191202 times .append (time_end - time_start )
192203 time_start = time .time ()
193204
0 commit comments