Skip to content
Merged
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Let us infer the values of the unknown parameters in our model by running MCMC u

>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> rng_key = random.key(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

```
Expand Down Expand Up @@ -111,7 +111,7 @@ The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates tha

>>> nuts_kernel = NUTS(eight_schools_noncentered)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> rng_key = random.key(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
>>> mcmc.print_summary(exclude_deterministic=False) # doctest: +SKIP

Expand Down Expand Up @@ -161,7 +161,7 @@ Now, let us assume that we have a new school for which we have not observed any
... return numpyro.sample('obs', dist.Normal(mu, tau))

>>> predictive = Predictive(new_school, mcmc.get_samples())
>>> samples_predictive = predictive(random.PRNGKey(1))
>>> samples_predictive = predictive(random.key(1))
>>> print(np.mean(samples_predictive['obs'])) # doctest: +SKIP
3.9886456

Expand Down Expand Up @@ -286,18 +286,18 @@ conda install -c conda-forge numpyro

1. Unlike in Pyro, `numpyro.sample('x', dist.Normal(0, 1))` does not work. Why?

You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNGKey](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes.
You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNG key](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.key)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes.

Your options are:

- Call the distribution directly and provide a `PRNGKey`, e.g. `dist.Normal(0, 1).sample(PRNGKey(0))`
- Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))`.
- Call the distribution directly and provide a PRNG key, e.g. `dist.Normal(0, 1).sample(key(0))`
- Provide the `rng_key` argument to `numpyro.sample`, e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=key(0))`.
- Wrap the code in a `seed` handler, used either as a context manager or as a function that wraps over the original callable. e.g.

```python
with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one
with handlers.seed(rng_seed=0): # random.key(0) is used
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNG key split from random.key(0)
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNG key split from the last one
```

, or as a higher order function:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Other kernels can be used similarly.
... return x, y_obs


>>> rng_key = random.PRNGKey(seed=42)
>>> rng_key = random.key(seed=42)
>>> rng_key, rng_subkey = random.split(rng_key)
>>> x, y_obs = generate_synthetic_data(
... rng_key=rng_subkey, start=0, stop=1, num=80, scale=0.3
Expand Down
4 changes: 2 additions & 2 deletions examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,12 @@ def main(args):
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(random.PRNGKey(0), *data)
mcmc.run(random.key(0), *data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *data)
discrete_samples = predictive(random.key(1), *data)

item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
discrete_samples["c"].squeeze(-1)
Expand Down
2 changes: 1 addition & 1 deletion examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def run_inference(model, args, rng_key, y):
def main(args):
# generate artificial dataset
num_data = args.num_data
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
t = jnp.arange(0, num_data)
y = jnp.sin(t) + random.normal(rng_key, (num_data,)) * 0.1

Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def main(args):
for i, model in enumerate(
(fully_pooled, not_pooled, partially_pooled, partially_pooled_with_logit)
):
rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))
rng_key, rng_key_predict = random.split(random.key(i + 1))
zs = run_inference(model, at_bats, hits, rng_key, args)
predict(model, at_bats, hits, zs, rng_key_predict, player_names)
predict(
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(args):
X, Y, X_test = get_data(N=N, D_X=D_X)

# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
rng_key, rng_key_predict = random.split(random.key(0))
samples = run_inference(model, args, rng_key, X, Y, D_H)

# predict Y_test at inputs X_test
Expand Down
2 changes: 1 addition & 1 deletion examples/capture_recapture.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def main(args):
)

model = models[args.model]
rng_key = random.PRNGKey(args.rng_seed)
rng_key = random.key(args.rng_seed)
run_inference(model, capture_history, sex, rng_key, args)


Expand Down
4 changes: 2 additions & 2 deletions examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def model(data, labels, subsample_size=None):


def benchmark_hmc(args, features, labels):
rng_key = random.PRNGKey(1)
rng_key = random.key(1)
start = time.time()
# a MAP estimate at the following source
# https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
Expand Down Expand Up @@ -174,7 +174,7 @@ def benchmark_hmc(args, features, labels):
subsample_size = 1000
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
svi_result = svi.run(random.key(2), 2000, features, labels)
params, losses = svi_result.params, svi_result.losses
plt.plot(losses)
plt.show()
Expand Down
4 changes: 2 additions & 2 deletions examples/cvae-flax/train_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def create_train_state(model, x, learning_rate_fn):
params = model.init(random.PRNGKey(0), x)
params = model.init(random.key(0), x)
tx = optax.adam(learning_rate_fn)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
return state
Expand Down Expand Up @@ -65,7 +65,7 @@ def train_baseline(
):
state = create_train_state(model, train_fetch(0, train_idx)[0], 0.003)

rng = random.PRNGKey(0)
rng = random.key(0)
best_val_loss = jnp.inf
best_state = state
for i in range(n_epochs):
Expand Down
4 changes: 2 additions & 2 deletions examples/cvae-flax/train_cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ def train_cvae(
n_epochs=100,
):
svi, state = create_train_state(
random.PRNGKey(23), model, guide, train_fetch, baseline_params, 0.003
random.key(23), model, guide, train_fetch, baseline_params, 0.003
)

p1 = baseline_params.unfreeze()["params"]["Dense_0"]["kernel"]
p2 = state.optim_state[1][0]["baseline$params"]["Dense_0"]["kernel"]
assert jnp.all(p1 == p2)

rng = random.PRNGKey(0)
rng = random.key(0)
best_val_loss = jnp.inf
best_state = state
for i in range(n_epochs):
Expand Down
4 changes: 2 additions & 2 deletions examples/dais_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def run_svi(rng_key, X, Y, guide_family="AutoDiagonalNormal", K=8):
print("[{}] final elbo: {:.2f}".format(guide_name, final_elbo))

return guide.sample_posterior(
random.PRNGKey(1), params, sample_shape=(args.num_samples,)
random.key(1), params, sample_shape=(args.num_samples,)
)


Expand All @@ -122,7 +122,7 @@ def run_nuts(mcmc_key, args, X, Y):
def main(args):
X, Y = get_data()

rng_keys = random.split(random.PRNGKey(0), 4)
rng_keys = random.split(random.key(0), 4)

# run SVI with an AutoDAIS guide for two values of K
dais8_samples = run_svi(rng_keys[1], X, Y, guide_family="AutoDAIS", K=8)
Expand Down
4 changes: 2 additions & 2 deletions examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def run_inference(model, args, rng_key):


def main(args):
rng_key = random.PRNGKey(0)
rng_key = random.key(0)

# do inference with centered parameterization
print(
Expand All @@ -84,7 +84,7 @@ def main(args):
# collect deterministic sites
reparam_samples = Predictive(
reparam_model, reparam_samples, return_sites=["x", "y"]
)(random.PRNGKey(1))
)(random.key(1))

# make plots
fig, (ax1, ax2) = plt.subplots(
Expand Down
6 changes: 3 additions & 3 deletions examples/gaussian_shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ def model(center1, center2, radius, width, enum=False):
def run_inference(args, data):
print("=== Performing Nested Sampling ===")
ns = NestedSampler(model)
ns.run(random.PRNGKey(0), **data, enum=args.enum)
ns.run(random.key(0), **data, enum=args.enum)
ns.print_summary()
# samples obtained from nested sampler are weighted, so
# we need to provide random key to resample from those weighted samples
ns_samples = ns.get_samples(random.PRNGKey(1), num_samples=args.num_samples)
ns_samples = ns.get_samples(random.key(1), num_samples=args.num_samples)

print("\n=== Performing MCMC Sampling ===")
if args.enum:
Expand All @@ -78,7 +78,7 @@ def run_inference(args, data):
num_warmup=args.num_warmup,
num_samples=args.num_samples,
)
mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
mcmc.run(random.key(2), **data, enum=args.enum)
mcmc.print_summary()
mcmc_samples = mcmc.get_samples()

Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main(args):
X, Y, X_test = get_data(N=args.num_data)

# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
rng_key, rng_key_predict = random.split(random.key(0))
samples = run_inference(model, args, rng_key, X, Y)

# do prediction
Expand Down
2 changes: 1 addition & 1 deletion examples/hmcecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(args):
else:
data, obs = (np.random.normal(size=(10, 28)), np.ones(10))

hmcecs_key, hmc_key = random.split(random.PRNGKey(args.rng_seed))
hmcecs_key, hmc_key = random.split(random.key(args.rng_seed))

# choose inner_kernel
if args.inner_kernel == "hmc":
Expand Down
4 changes: 2 additions & 2 deletions examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ def main(args):
supervised_words,
unsupervised_words,
) = simulate_data(
random.PRNGKey(1),
random.key(1),
num_categories=args.num_categories,
num_words=args.num_words,
num_supervised_data=args.num_supervised,
num_unsupervised_data=args.num_unsupervised,
)
print("Starting inference...")
rng_key = random.PRNGKey(2)
rng_key = random.key(2)
start = time.time()
kernel = NUTS(semi_supervised_hmm)
mcmc = MCMC(
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def main(args):

logger.info("Each sequence has shape {}".format(sequences[0].shape))
logger.info("Starting inference...")
rng_key = random.PRNGKey(2)
rng_key = random.key(2)
start = time.time()
kernel = {"nuts": NUTS, "hmc": HMC}[args.kernel](model)
mcmc = MCMC(
Expand Down
6 changes: 3 additions & 3 deletions examples/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def predict(model, args, samples, rng_key, y, n_seasons):

def main(args):
# generate artificial dataset
rng_key, _ = random.split(random.PRNGKey(0))
rng_key, _ = random.split(random.key(0))
T = args.T
t = jnp.linspace(0, T + args.future, (T + args.future) * N_POINTS_PER_UNIT)
y = jnp.sin(2 * np.pi * t) + 0.3 * t + jax.random.normal(rng_key, t.shape) * 0.1
Expand All @@ -157,11 +157,11 @@ def main(args):
t_test = t[-args.future * N_POINTS_PER_UNIT :]

# do inference
rng_key, _ = random.split(random.PRNGKey(1))
rng_key, _ = random.split(random.key(1))
samples = run_inference(holt_winters, args, rng_key, y_train, n_seasons)

# do prediction
rng_key, _ = random.split(random.PRNGKey(2))
rng_key, _ = random.split(random.key(2))
preds = predict(holt_winters, args, samples, rng_key, y_train, n_seasons)
mean_preds = preds.mean(axis=0)
hpdi_preds = hpdi(preds)
Expand Down
4 changes: 2 additions & 2 deletions examples/horseshoe_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def main(args):
X, Y = get_data(N=N, D_X=D_X, response="continuous")

# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
rng_key, rng_key_predict = random.split(random.key(0))
summary = run_inference(model_normal_likelihood, args, rng_key, X, Y)

# lambda should only be large for the first 3 dimensions, which
Expand All @@ -150,7 +150,7 @@ def main(args):
X, Y = get_data(N=4 * N, D_X=D_X, response="binary")

# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
rng_key, rng_key_predict = random.split(random.key(0))
summary = run_inference(model_bernoulli_likelihood, args, rng_key, X, Y)

# lambda should only be large for the first 3 dimensions, which
Expand Down
2 changes: 1 addition & 1 deletion examples/hsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def main(args):
num_chains=args.num_chains,
progress_bar=(not is_sphinxbuild),
)
mcmc.run(jax.random.PRNGKey(0), **data_dict)
mcmc.run(jax.random.key(0), **data_dict)
if not is_sphinxbuild:
mcmc.print_summary()

Expand Down
6 changes: 3 additions & 3 deletions examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.random import key

import numpyro
from numpyro import optim
Expand All @@ -29,14 +29,14 @@ def guide(data):

def main(args):
# Generate some data.
data = random.normal(PRNGKey(0), shape=(100,)) + 3.0
data = random.normal(key(0), shape=(100,)) + 3.0

# Construct an SVI object so we can do variational inference on our
# model/guide pair.
adam = optim.Adam(args.learning_rate)

svi = SVI(model, guide, adam, Trace_ELBO(num_particles=100))
svi_state = svi.init(PRNGKey(0), data)
svi_state = svi.init(key(0), data)

# Training loop
def body_fn(i, val):
Expand Down
2 changes: 1 addition & 1 deletion examples/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def main(args):
print_model_shape(model, a, s2, t, lookup, population)

print("Starting inference...")
rng_key = random.PRNGKey(args.rng_seed)
rng_key = random.key(args.rng_seed)
run_inference(model, a, s2, t, lookup, population, deaths, rng_key, args)


Expand Down
12 changes: 5 additions & 7 deletions examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def main(args):
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(random.PRNGKey(0))
mcmc.run(random.key(0))
mcmc.print_summary()
vanilla_samples = mcmc.get_samples()["x"].copy()

Expand All @@ -79,10 +79,10 @@ def main(args):
svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())

print("Start training guide...")
svi_result = svi.run(random.PRNGKey(1), args.num_iters)
svi_result = svi.run(random.key(1), args.num_iters)
print("Finish training guide. Extract samples...")
guide_samples = guide.sample_posterior(
random.PRNGKey(2), svi_result.params, sample_shape=(args.num_samples,)
random.key(2), svi_result.params, sample_shape=(args.num_samples,)
)["x"].copy()

print("\nStart NeuTra HMC...")
Expand All @@ -96,7 +96,7 @@ def main(args):
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(random.PRNGKey(3))
mcmc.run(random.key(3))
mcmc.print_summary()
zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
print("Transform samples into unwarped space...")
Expand All @@ -108,9 +108,7 @@ def main(args):
# make plots

# guide samples (for plotting)
guide_base_samples = dist.Normal(jnp.zeros(2), 1.0).sample(
random.PRNGKey(4), (1000,)
)
guide_base_samples = dist.Normal(jnp.zeros(2), 1.0).sample(random.key(4), (1000,))
guide_trans_samples = neutra.transform_sample(guide_base_samples)["x"]

x1 = jnp.linspace(-3, 3, 100)
Expand Down
Loading