diff --git a/test/test_optimizers.py b/test/test_optimizers.py index fd5274697..1e1ddc260 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -123,4 +123,8 @@ def my_fn(state, g): state = my_fn(state, jnp.ones(10) * 2.0) state = my_fn(state, jnp.ones(10) * 3.0) - assert my_fn_calls == 1 + if uses_value_arg: + # Dtype is different on the first call vs the rest of the calls + assert my_fn_calls in (1, 2) + else: + assert my_fn_calls == 1