Skip to content

Commit 1a7c900

Browse files
authored
Update (#176)
[ghstack-poisoned]
1 parent feee51a commit 1a7c900

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/test_aot_eager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def test_aot_eager_bitwise_equivalent(llama3_debug_model):
5858
x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")
5959
torch.manual_seed(3999)
6060
r1 = llama3_debug_model(x)
61+
grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters())
6162
torch.manual_seed(3999)
6263
r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x)
64+
grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters())
6365
assert torch.equal(r1, r2) # bitwise equal
66+
for g1, g2 in zip(grads1, grads2):
67+
assert torch.equal(g1, g2)

0 commit comments

Comments
 (0)