We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
to_cotangent_aval()
SymbolicZero
1 parent 6940903 commit 9630a2bCopy full SHA for 9630a2b
jax/_src/custom_derivatives.py
@@ -945,7 +945,7 @@ def append(x, d):
945
if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0:
946
results.append(Zero(a.to_tangent_aval()))
947
elif type(ct) is SymbolicZero:
948
- if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval):
+ if not core.typecompat(a.to_cotangent_aval(), a_ := ct.aval):
949
msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
950
"that does not match the corresponding input tangent shape/dtype: "
951
f"at output{keystr(kp)} the SymbolicZero had shape/dtype "
0 commit comments