Skip to content

Conversation

@Qazalbash
Copy link
Collaborator

As per the official JAX docs, jax.random.PRNGKey is a legacy API and should be replaced with jax.random.key wherever possible 1. This PR replaces calls to jax.random.PRNGKey with jax.random.key, accompanied by necessary documentation.

I have replaced "PRNGKey" with "PRNG key" in the documentation to distinguish PRNG key as a concept from the function jax.random.PRNGKey.

Footnotes

  1. See note in https://docs.jax.dev/en/latest/jax.random.html#prng-keys

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!!

@Qazalbash
Copy link
Collaborator Author

The following test case is failing on optax-v0.2.7 but passing on older versions.

FAILED test/test_optimizers.py::test_numpyrooptim_no_double_jit[chain-args14-kwargs14-True] - assert 1 == 2

I am unable to reproduce the following test cases on my local machine. Maybe they pass in the next CI.

FAILED test/test_pickle.py::test_pickle_hmc[BarkerMH] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc[HMC] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc[NUTS] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc[SA] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[BarkerMH] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[HMC] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[NUTS] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[SA] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_discrete_hmc[DiscreteHMCGibbs] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_discrete_hmc[MixedHMC] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmcecs - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_autoguide[AutoDelta] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_autoguide[AutoDiagonalNormal] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_autoguide[AutoNormal] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_mcmc_pickle_post_warmup - TypeError: cannot pickle 'PRNGKeyArray' object

@Qazalbash
Copy link
Collaborator Author

@fehiepsi, pickling failed again. Can you rerun the CI?

@Qazalbash
Copy link
Collaborator Author

Need #2136 and #2137 for CI to pass.

@fehiepsi
Copy link
Member

Could you add a simple test in test_pickle to check if we can pickle a PRNGKey? if it happens on CI and is unrelated to numpyro, we can report the issue to jax devs.

@fehiepsi
Copy link
Member

It seems the tests are failing consistently. How about using PRNGKey like before in this test?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants