Skip to content

Commit 6395b5f

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Remove jax_custom_vjp_disable_shape_check config option
PiperOrigin-RevId: 845005784
1 parent e914ced commit 6395b5f

File tree

3 files changed

+1
-29
lines changed

3 files changed

+1
-29
lines changed

jax/_src/config.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,13 +1819,6 @@ def _validate_default_device(val):
18191819
upgrade=False,
18201820
help='Temporary workaround to disable an error check in vmap-of-shmap.')
18211821

1822-
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
1823-
custom_vjp_disable_shape_check = bool_state(
1824-
name='jax_custom_vjp_disable_shape_check',
1825-
default=False,
1826-
upgrade=True,
1827-
help='Disable the check from #19009 to enable some custom_vjp hacks.')
1828-
18291822
mutable_array_checks = bool_state(
18301823
name='jax_mutable_array_checks',
18311824
default=True,

jax/_src/custom_derivatives.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -966,8 +966,7 @@ def append(x, d):
966966
else:
967967
if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and
968968
not _ref_typecompat(a.to_tangent_aval(), a_) and
969-
not (_temporary_dtype_exception(a, a_) or
970-
_temporary_shape_exception(a, a_))):
969+
not _temporary_dtype_exception(a, a_)):
971970
msg = ("Custom VJP bwd rule must produce an output with the same "
972971
"shape/dtypes as the args tuple of the primal function, but at "
973972
f"output{keystr(kp)} the bwd rule produced an output of "
@@ -990,9 +989,6 @@ def _temporary_dtype_exception(a, a_) -> bool:
990989
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
991990
return False
992991

993-
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
994-
def _temporary_shape_exception(a, a_) -> bool:
995-
return config.custom_vjp_disable_shape_check.value
996992

997993
class CustomVJPCallPrimitive(core.Primitive):
998994
multiple_results = True

tests/custom_api_test.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2971,23 +2971,6 @@ def foo_bwd(_, g):
29712971
r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'):
29722972
jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4))
29732973

2974-
def test_bwd_rule_shape_mismatch_disable(self):
2975-
# TODO(mattjj): remove this test when the config option is removed
2976-
@jax.custom_vjp
2977-
def foo(x, y):
2978-
return x
2979-
2980-
def foo_fwd(x, y):
2981-
return x, None
2982-
2983-
def foo_bwd(_, g):
2984-
return jnp.zeros(3), jnp.zeros(3)
2985-
2986-
foo.defvjp(foo_fwd, foo_bwd)
2987-
2988-
with config.custom_vjp_disable_shape_check(True):
2989-
jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4))
2990-
29912974
def test_bwd_rule_can_produce_list_or_tuple(self):
29922975
@jax.custom_vjp
29932976
def f(x, y):

0 commit comments

Comments
 (0)