Skip to content

Commit 6d72e54

Browse files
committed
fix train encoder depths script
1 parent c7e6bf7 commit 6d72e54

File tree

1 file changed

+3
-12
lines changed

1 file changed

+3
-12
lines changed

train/scripts/train_ed.py

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -38,13 +38,6 @@
38 38
help="encoder/decoder depth",
39 39
default=12,
40 40
)
41 -
parser.add_argument(
42 -
"--grad_accum",
43 -
"-g",
44 -
type=int,
45 -
help="gradient accumulation steps; must be a multiple of the world size",
46 -
default=1,
47 -
)
48 41
args = parser.parse_args()
49 42

50 43
if args.xpu:
@@ -90,8 +83,6 @@
90 83
RANK = int(os.environ["RANK"])
91 84
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
92 85

93 -
assert args.grad_accum % WORLD_SIZE == 0
94 -
95 86
def main(download_path: str, encoder_depth: int, xpu: bool = False):
96 87
if xpu:
97 88
comms_backend = "ccl"
@@ -149,8 +140,8 @@ def main(download_path: str, encoder_depth: int, xpu: bool = False):
149 140
data_path=download_path,
150 141
t=1,
151 142
static_data=Path("static.nc"),
152 -
surface_data=Path("2023-01-surface-level.nc"),
153 -
atmos_data=Path("2023-01-atmospheric.nc"),
143 +
surface_data=Path("2023-01-surface-level-34.nc"),
144 +
atmos_data=Path("2023-01-atmospheric-34.nc"),
154 145
)
155 146
sampler = DistributedSampler(dataset)
156 147

@@ -164,7 +155,7 @@ def main(download_path: str, encoder_depth: int, xpu: bool = False):
164 155

165 156
times = []
166 157

167 -
n_batches_per_optim = args.grad_accum // WORLD_SIZE
158 +
n_batches_per_optim = 8 / WORLD_SIZE
168 159

169 160
time_start = time.time()
170 161
for batch, (X, y) in enumerate(data_loader):

0 commit comments

Comments (0)