Skip to content

Commit 9630a2b

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Use to_cotangent_aval() in SymbolicZero check in _flatten_bwd
PiperOrigin-RevId: 845603376
1 parent 6940903 commit 9630a2b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/custom_derivatives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,7 @@ def append(x, d):
945945
if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0:
946946
results.append(Zero(a.to_tangent_aval()))
947947
elif type(ct) is SymbolicZero:
948-
if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval):
948+
if not core.typecompat(a.to_cotangent_aval(), a_ := ct.aval):
949949
msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
950950
"that does not match the corresponding input tangent shape/dtype: "
951951
f"at output{keystr(kp)} the SymbolicZero had shape/dtype "

0 commit comments

Comments
 (0)