Skip to content

Commit 8ce3512

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix tree equality error message in si_vjp
PiperOrigin-RevId: 843950969
1 parent 0d58415 commit 8ce3512

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2308,7 +2308,7 @@ def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known,
23082308
cts_flat, out_tree_ = tree_flatten(ct)
23092309
if out_tree_ != out_tree:
23102310
raise ValueError(f"unexpected tree structure of argument to vjp function: "
2311-
f"got {out_tree}, but expected to match {out_tree_}")
2311+
f"got {out_tree_}, but expected to match {out_tree}")
23122312
for arg, aval in zip(cts_flat, out_primal_avals):
23132313
ct_aval = shaped_abstractify(arg)
23142314
ct_aval_expected = aval.to_cotangent_aval()

0 commit comments

Comments
 (0)