Skip to content

Commit 17d6070

Browse files
committed
add more timings to train.py
1 parent e69dabc commit 17d6070

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

train/scripts/train.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,23 +179,29 @@ def main(download_path: str, shard: bool, xpu: bool = False):
179179
with torch.autocast(device_type=device_type):
180180
print("performing forward pass...")
181181
pred = model(X)
182+
print(f"finished model forward: {time.time()-time_start}")
182183

183184
# only one of these is necessary
184185
pred = pred.to(device)
186+
print(f"finished pred to device: {time.time()-time_start}")
185187
y = y.to(device)
188+
print(f"finished y to device: {time.time()-time_start}")
186189

187190
# mean absolute error of one variable
188191
print("calculating loss...")
189192

190193
# Todo: Are pred's of type PyTree and does it matter?
191194
loss = mae(pred, y)
192-
195+
print(f"finished loss calc: {time.time()-time_start}")
196+
193197
print("performing backward pass...")
194198
loss.backward()
199+
print(f"finished loss backward: {time.time()-time_start}")
195200

196201
if batch % n_batches_per_optim == 0:
197202
print("optimizing...")
198203
optimizer.step()
204+
print(f"finished optimizer step: {time.time()-time_start}")
199205

200206
time_end = time.time()
201207
print(f"Time for 1 iteration: {time_end - time_start}")

0 commit comments

Comments
 (0)