diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index a20d5648217b..bbe1a1f9bdf9 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -155,6 +155,7 @@ def __call__( def register_discharge_rule(prim: core.Primitive): def register(f: DischargeRule): _discharge_rules[prim] = f + return f return register @@ -630,6 +631,27 @@ def _closed_call_discharge_rule( assert next(ref_vals_iter, sentinel) is sentinel return new_invals, out_vals +@register_discharge_rule(core.call_p) +def _call_discharge_rule( + in_avals: Sequence[core.AbstractValue], _,*args, + call_jaxpr: core.Jaxpr): + closed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) + discharged_closed_jaxpr, num_outs, fun = _cached_closed_jaxpr_discharge( + closed_call_jaxpr) + discharged_call_jaxpr = discharged_closed_jaxpr.jaxpr + discharged_consts = discharged_closed_jaxpr.consts + discharged_call_jaxpr = pe.convert_constvars_jaxpr(discharged_call_jaxpr) + out_and_ref_vals = core.closed_call_p.bind(fun, *discharged_consts, *args, + call_jaxpr=discharged_call_jaxpr) + out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) + ref_vals_iter = iter(ref_vals) + new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) + else None for aval in in_avals) + sentinel = object() + assert next(ref_vals_iter, sentinel) is sentinel + return new_invals, out_vals + + # # `run_state` run_state_p = core.Primitive("run_state") diff --git a/tests/state_test.py b/tests/state_test.py index c75ffb2043e4..16fbaecb8e5a 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -1385,6 +1385,25 @@ def body(_, x_ref): self.assertEqual(a.shape, ()) self.assertEqual(b.shape, (3,)) + @parameterized.named_parameters( + ("call_primitive", core.call_p), + ("closed_call_primitive", core.closed_call_p), + ) + def test_call_primitive_discharges(self, prim): + + def g(y_ref, x): + x_ref = jax.new_ref(x) + y_ref[...] = jnp.exp(x_ref[...]) + return [jax.freeze(y_ref)] + + def f(x): + y_ref = jax.new_ref(jnp.zeros_like(x)) + g_ = partial(g, y_ref) + return prim.bind( + lu.wrap_init(g_, debug_info=api_util.debug_info("f", g, (x,), {})), x + )[0] + out = f(4.) + np.testing.assert_array_equal(out, jnp.exp(4.)) class GeneralRefTest(jtu.JaxTestCase):