Skip to content

Commit 254918c

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix ct_check to account for None cotangents too
PiperOrigin-RevId: 845824336
1 parent 579009c commit 254918c

File tree

1 file changed

+3
-3
lines changed
  • jax/_src/interpreters

1 file changed

+3
-3
lines changed

jax/_src/interpreters/ad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,12 +421,12 @@ def __init__(self, aval, ref=None):
421421

422422
def accum(self, x):
423423
assert x is not Zero
424-
ct_check(self, x)
425424
if isinstance(x, Zero) or x is None:
426425
return
427-
elif self.ref is None:
426+
if self.ref is None:
428427
self.ref = core.new_ref(x)
429428
else:
429+
ct_check(self, x)
430430
self.ref.addupdate(x)
431431

432432
def freeze(self):
@@ -449,8 +449,8 @@ def __init__(self, aval, val=None):
449449
self.val = Zero(aval) if val is None else val
450450

451451
def accum(self, x):
452-
ct_check(self, x)
453452
if x is not None:
453+
ct_check(self, x)
454454
self.val = add_tangents(self.val, x)
455455

456456
def freeze(self):

0 commit comments

Comments
 (0)