3838 help = "encoder/decoder depth" ,
3939 default = 12 ,
4040)
41- parser .add_argument (
42- "--grad_accum" ,
43- "-g" ,
44- type = int ,
45- help = "gradient accumulation steps; must be a multiple of the world size" ,
46- default = 1 ,
47- )
4841args = parser .parse_args ()
4942
5043if args .xpu :
9083 RANK = int (os .environ ["RANK" ])
9184 LOCAL_RANK = int (os .environ ["LOCAL_RANK" ])
9285
93- assert args .grad_accum % WORLD_SIZE == 0
94-
9586def main (download_path : str , encoder_depth : int , xpu : bool = False ):
9687 if xpu :
9788 comms_backend = "ccl"
@@ -149,8 +140,8 @@ def main(download_path: str, encoder_depth: int, xpu: bool = False):
149140 data_path = download_path ,
150141 t = 1 ,
151142 static_data = Path ("static.nc" ),
152- surface_data = Path ("2023-01-surface-level.nc" ),
153- atmos_data = Path ("2023-01-atmospheric.nc" ),
143+ surface_data = Path ("2023-01-surface-level-34 .nc" ),
144+ atmos_data = Path ("2023-01-atmospheric-34 .nc" ),
154145 )
155146 sampler = DistributedSampler (dataset )
156147
@@ -164,7 +155,7 @@ def main(download_path: str, encoder_depth: int, xpu: bool = False):
164155
165156 times = []
166157
167- n_batches_per_optim = args . grad_accum / / WORLD_SIZE
158+ n_batches_per_optim = 8 / WORLD_SIZE
168159
169160 time_start = time .time ()
170161 for batch , (X , y ) in enumerate (data_loader ):
0 commit comments