@@ -31,6 +31,20 @@
     help="path to download directory",
     default="../../era5/era_v_inf",
 )
+parser.add_argument(
+    "--encoders",
+    "-e",
+    type=int,
+    help="encoder/decoder depth",
+    default=12,
+)
+parser.add_argument(
+    "--grad_accum",
+    "-g",
+    type=int,
+    help="gradient accumulation steps; must be a multiple of the world size",
+    default=1,
+)
 args = parser.parse_args()
 
 if args.xpu:
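
The --grad_accum help text promises a divisibility constraint that argparse alone does not enforce; the script only checks it later with an assert. A minimal sketch of enforcing it at parse time instead, assuming WORLD_SIZE has already been read from the environment before the parser is built (in the script it may only appear later, next to the RANK/LOCAL_RANK lookups):

    def multiple_of_world_size(value: str) -> int:
        # Hypothetical argparse type callable, not part of the commit.
        steps = int(value)
        if steps < 1 or steps % WORLD_SIZE != 0:
            raise argparse.ArgumentTypeError(
                f"expected a positive multiple of the world size ({WORLD_SIZE}), got {steps}"
            )
        return steps

Passing type=multiple_of_world_size instead of type=int on --grad_accum would fail fast with a readable message rather than an AssertionError mid-startup.
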
@@ -76,8 +90,9 @@
 RANK = int(os.environ["RANK"])
 LOCAL_RANK = int(os.environ["LOCAL_RANK"])
 
+assert args.grad_accum % WORLD_SIZE == 0
 
-def main(download_path: str, xpu: bool = False):
+def main(download_path: str, encoder_depth: int, xpu: bool = False):
     if xpu:
         comms_backend = "ccl"
         device_type = "xpu"
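
The assert pins down what --grad_accum means: the total number of micro-batches per optimizer step across all ranks. With --grad_accum 8 on 4 ranks, each rank accumulates 8 // 4 = 2 batches between steps; with --grad_accum 6 on 4 ranks, the integer division used below would silently truncate to 1 batch per rank (an effective accumulation of 4, not 6), which is exactly the mismatch the assert rules out.
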
@@ -102,15 +117,17 @@ def main(download_path: str, xpu: bool = False):
     model = Aurora(
         use_lora=False,  # Model was not fine-tuned.
         autocast=True,  # Use AMP.
-        encoder_depths=(12, 12, 12),
+        encoder_depths=(encoder_depth, encoder_depth, encoder_depth),
         encoder_num_heads=(4, 8, 16),
-        decoder_depths=(12, 12, 12),
+        decoder_depths=(encoder_depth, encoder_depth, encoder_depth),
         decoder_num_heads=(16, 8, 4),
         embed_dim=256,
         num_heads=8,
     )
     # Can no longer load the checkpoint, since the model size now differs from the pretrained weights.
     # model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
+    if not xpu:
+        torch.cuda.set_device(LOCAL_RANK)
 
     download_path = Path(download_path)
 
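
Since --encoders now changes the model size, logging how large each configuration is makes benchmark runs easier to compare. A small hypothetical addition (not in the commit) using only standard PyTorch, placed after the model is built:

    # Hypothetical: report the trainable parameter count once per job.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if RANK == 0:
        print(f"trainable parameters: {n_params / 1e6:.1f}M", flush=True)
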
@@ -147,6 +164,8 @@ def main(download_path: str, xpu: bool = False):
 
     times = []
 
+    n_batches_per_optim = args.grad_accum // WORLD_SIZE
+
     time_start = time.time()
     for batch, (X, y) in enumerate(data_loader):
         print(f"batch {batch}...", flush=True)
@@ -170,8 +189,9 @@ def main(download_path: str, xpu: bool = False):
         print("performing backward pass...", flush=True)
         loss.backward()
 
-        print("optimizing...", flush=True)
-        optimizer.step()
+        if (batch + 1) % n_batches_per_optim == 0:
+            print("optimizing...", flush=True)
+            optimizer.step()
 
         time_end = time.time()
         times.append(time_end - time_start)
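
Two details of accumulation logic are easy to get wrong, shown in the minimal sketch below. compute_loss is a stand-in for the script's actual forward/loss code, which lies outside this hunk, and the sketch assumes no zero_grad call exists elsewhere in the loop: without scaling, gradient magnitudes grow with n_batches_per_optim, and without zeroing after each step, gradients also leak across optimizer steps.

    for batch, (X, y) in enumerate(data_loader):
        # Scale so the summed gradient matches one large batch.
        loss = compute_loss(model, X, y) / n_batches_per_optim
        loss.backward()
        if (batch + 1) % n_batches_per_optim == 0:
            optimizer.step()
            optimizer.zero_grad()  # otherwise gradients carry into the next step
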
@@ -185,6 +205,9 @@ def main(download_path: str, xpu: bool = False):
     avg_time = sum([sum(t[1:]) for t in gathered_times]) / sum(
         [len(t[1:]) for t in gathered_times]
     )
+    print(
+        f"Encoder/decoder depth: ({encoder_depth}, {encoder_depth}, {encoder_depth})", flush=True
+    )
     print(
         f"Average time per epoch (ignoring first): {avg_time} seconds", flush=True
     )
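
gathered_times is collected from all ranks outside this hunk; the averaging then drops each rank's first measurement as warm-up. A sketch of how the gather is typically done (an assumption, since the actual call is not shown):

    # Assumption: one list of per-batch times per rank, gathered to every rank.
    gathered_times = [None] * WORLD_SIZE
    torch.distributed.all_gather_object(gathered_times, times)
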
@@ -206,4 +229,4 @@ def main(download_path: str, xpu: bool = False):
     print("done", flush=True)
 
 
-main(args.download_path, xpu=args.xpu)
+main(args.download_path, args.encoders, xpu=args.xpu)
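
For completeness, a hypothetical launch that exercises both new flags. The script name train.py and the torchrun launcher are guesses; the RANK/LOCAL_RANK environment lookups only indicate a torchrun-style launcher. With 4 processes, --grad_accum 8 satisfies the divisibility assert:

    torchrun --nproc_per_node=4 train.py --encoders 16 --grad_accum 8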