From 152af4c7c9c7c0ca0940755916aeee48b0a4b83f Mon Sep 17 00:00:00 2001 From: = Date: Mon, 23 Jun 2025 15:05:06 -0400 Subject: [PATCH 01/10] draft of MCLMC --- mclmc.py | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 mclmc.py diff --git a/mclmc.py b/mclmc.py new file mode 100644 index 000000000..80bd54b81 --- /dev/null +++ b/mclmc.py @@ -0,0 +1,207 @@ + + +import argparse +from collections import namedtuple +import os + +import matplotlib.pyplot as plt + +import jax +import jax.numpy as jnp +from jax import random + + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC +from numpyro.infer.mcmc import MCMCKernel +import blackjax +from numpyro.infer.util import initialize_model +from blackjax.util import pytree_size +from blackjax.mcmc.integrators import ( + IntegratorState) + + +FullState = namedtuple("FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"]) + +class MCLMC(MCMCKernel): + """ + Microcanonical Langevin Monte Carlo (MCLMC) kernel. + + :param model: Python callable containing Pyro primitives. + :param step_size: Initial step size for the Langevin dynamics. + :param num_steps: Number of steps to take in each MCMC iteration. + :param integrator_type: Type of integrator to use (e.g. "mclachlan"). + :param diagonal_preconditioning: Whether to use diagonal preconditioning. + :param num_tuning_steps: Number of tuning steps to use. + :param desired_energy_var: Desired energy variance for tuning. 
+ """ + + + + def __init__( + self, + model=None, + desired_energy_var=5e-4, + diagonal_preconditioning=True, + ): + if model is None: + raise ValueError("Model must be specified for MCLMC") + self._model = model + self._diagonal_preconditioning = diagonal_preconditioning + self._desired_energy_var = desired_energy_var + self._init_fn = None + self._sample_fn = None + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "position" + + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + """ + Initialize the MCLMC kernel. + + :param rng_key: Random number generator key + :param num_warmup: Number of warmup steps + :param init_params: Initial parameters + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Initial state + """ + + init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split(rng_key, 4) + + init_params, potential_fn_gen, _, _ = initialize_model( + init_model_key, + self._model, + model_args=(), + dynamic_args=True, + ) + + logdensity_fn = lambda position: -potential_fn_gen()(position) + initial_position = init_params.z + self.logdensity_fn = logdensity_fn + + sampler_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=self.logdensity_fn, + rng_key=init_state_key, + ) + + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=inverse_mass_matrix, + ) + + self.dim = pytree_size(initial_position) + + # num_steps is a dummy param here + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=100, + state=sampler_state, + rng_key=rng_key_tune, + diagonal_preconditioning=True, + frac_tune3=num_warmup / (3 * 100), + 
frac_tune2=num_warmup / (3 * 100), + frac_tune1=num_warmup / (3 * 100), + desired_energy_var=5e-4 + ) + + self.adapt_state = blackjax_mclmc_sampler_params + + return FullState(blackjax_state_after_tuning.position, blackjax_state_after_tuning.momentum, blackjax_state_after_tuning.logdensity, blackjax_state_after_tuning.logdensity_grad, run_key) + + + def sample(self, state, model_args, model_kwargs): + """ + Run MCLMC from the given state and return the resulting state. + + :param state: Current state + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Next state after running MCLMC + """ + + mclmc_state = IntegratorState(state.position, state.momentum, state.logdensity, state.logdensity_grad) + rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=self.logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, + ) + + new_state, info = kernel( + rng_key=rng_key_sample, + state=mclmc_state, + step_size=self.adapt_state.step_size, + L=self.adapt_state.L + ) + + return FullState(new_state.position, new_state.momentum, new_state.logdensity, new_state.logdensity_grad, rng_key) + +if __name__ == "__main__": + + def gaussian_2d_model(): + """ + A simple 2D Gaussian model with mean [0, 0] and covariance [[1, 0.5], [0.5, 1]]. + """ + x = numpyro.sample("x", dist.Normal(0.0, 1.0)) + y = numpyro.sample("y", dist.Normal(0.0, 1.0)) + numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array([0.0])) + return x + y + + + def run_inference(model, args, rng_key): + """ + Run MCMC inference on the given model. 
+ + :param model: The model to run inference on + :param args: Command line arguments + :param rng_key: Random number generator key + :return: MCMC object + """ + kernel = MCLMC( + model=model, + diagonal_preconditioning=True, + desired_energy_var=5e-4, + ) + + mcmc = MCMC( + kernel, + num_warmup=1000, + num_samples=1000, + num_chains=1, + progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, + ) + + mcmc.run(rng_key) + mcmc.print_summary(exclude_deterministic=False) + + samples = mcmc.get_samples() + plt.figure(figsize=(8, 8)) + plt.scatter(samples['x'], samples['y'], alpha=0.5) + plt.xlabel('x') + plt.ylabel('y') + plt.title('MCLMC samples from 2D Gaussian') + plt.grid(True) + plt.savefig('mclmc_samples.png') + plt.close() + + return mcmc + + + rng_key = random.PRNGKey(0) + mcmc = run_inference(gaussian_2d_model, args=None, rng_key=rng_key) + From a1e7e0ba510c52791cab450355d2fdc3540f0d3d Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Tue, 20 Jan 2026 23:27:14 +0100 Subject: [PATCH 02/10] feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel Add MCLMC inference algorithm as a new MCMCKernel that wraps blackjax's MCLMC implementation. This provides an alternative gradient-based MCMC method to NUTS/HMC. 
Features: - MCLMC kernel with automatic step size and trajectory length tuning - Optional blackjax dependency with informative error message - postprocess_fn for constrained/unconstrained transformations - Diagnostics string for progress bar - Comprehensive test suite References: - Microcanonical Hamiltonian Monte Carlo (arXiv:2212.08549) --- mclmc.py | 207 ------------------------------------ numpyro/infer/mclmc.py | 219 +++++++++++++++++++++++++++++++++++++++ test/infer/test_mclmc.py | 155 +++++++++++++++++++++++++++ 3 files changed, 374 insertions(+), 207 deletions(-) delete mode 100644 mclmc.py create mode 100644 numpyro/infer/mclmc.py create mode 100644 test/infer/test_mclmc.py diff --git a/mclmc.py b/mclmc.py deleted file mode 100644 index 80bd54b81..000000000 --- a/mclmc.py +++ /dev/null @@ -1,207 +0,0 @@ - - -import argparse -from collections import namedtuple -import os - -import matplotlib.pyplot as plt - -import jax -import jax.numpy as jnp -from jax import random - - -import numpyro -import numpyro.distributions as dist -from numpyro.infer import MCMC -from numpyro.infer.mcmc import MCMCKernel -import blackjax -from numpyro.infer.util import initialize_model -from blackjax.util import pytree_size -from blackjax.mcmc.integrators import ( - IntegratorState) - - -FullState = namedtuple("FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"]) - -class MCLMC(MCMCKernel): - """ - Microcanonical Langevin Monte Carlo (MCLMC) kernel. - - :param model: Python callable containing Pyro primitives. - :param step_size: Initial step size for the Langevin dynamics. - :param num_steps: Number of steps to take in each MCMC iteration. - :param integrator_type: Type of integrator to use (e.g. "mclachlan"). - :param diagonal_preconditioning: Whether to use diagonal preconditioning. - :param num_tuning_steps: Number of tuning steps to use. - :param desired_energy_var: Desired energy variance for tuning. 
- """ - - - - def __init__( - self, - model=None, - desired_energy_var=5e-4, - diagonal_preconditioning=True, - ): - if model is None: - raise ValueError("Model must be specified for MCLMC") - self._model = model - self._diagonal_preconditioning = diagonal_preconditioning - self._desired_energy_var = desired_energy_var - self._init_fn = None - self._sample_fn = None - self._postprocess_fn = None - - @property - def model(self): - return self._model - - @property - def sample_field(self): - return "position" - - def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): - """ - Initialize the MCLMC kernel. - - :param rng_key: Random number generator key - :param num_warmup: Number of warmup steps - :param init_params: Initial parameters - :param model_args: Model arguments - :param model_kwargs: Model keyword arguments - :return: Initial state - """ - - init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split(rng_key, 4) - - init_params, potential_fn_gen, _, _ = initialize_model( - init_model_key, - self._model, - model_args=(), - dynamic_args=True, - ) - - logdensity_fn = lambda position: -potential_fn_gen()(position) - initial_position = init_params.z - self.logdensity_fn = logdensity_fn - - sampler_state = blackjax.mcmc.mclmc.init( - position=initial_position, - logdensity_fn=self.logdensity_fn, - rng_key=init_state_key, - ) - - kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - inverse_mass_matrix=inverse_mass_matrix, - ) - - self.dim = pytree_size(initial_position) - - # num_steps is a dummy param here - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - num_tuning_integrator_steps, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=100, - state=sampler_state, - rng_key=rng_key_tune, - diagonal_preconditioning=True, - frac_tune3=num_warmup / (3 * 100), - 
frac_tune2=num_warmup / (3 * 100), - frac_tune1=num_warmup / (3 * 100), - desired_energy_var=5e-4 - ) - - self.adapt_state = blackjax_mclmc_sampler_params - - return FullState(blackjax_state_after_tuning.position, blackjax_state_after_tuning.momentum, blackjax_state_after_tuning.logdensity, blackjax_state_after_tuning.logdensity_grad, run_key) - - - def sample(self, state, model_args, model_kwargs): - """ - Run MCLMC from the given state and return the resulting state. - - :param state: Current state - :param model_args: Model arguments - :param model_kwargs: Model keyword arguments - :return: Next state after running MCLMC - """ - - mclmc_state = IntegratorState(state.position, state.momentum, state.logdensity, state.logdensity_grad) - rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) - - kernel = blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=self.logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, - ) - - new_state, info = kernel( - rng_key=rng_key_sample, - state=mclmc_state, - step_size=self.adapt_state.step_size, - L=self.adapt_state.L - ) - - return FullState(new_state.position, new_state.momentum, new_state.logdensity, new_state.logdensity_grad, rng_key) - -if __name__ == "__main__": - - def gaussian_2d_model(): - """ - A simple 2D Gaussian model with mean [0, 0] and covariance [[1, 0.5], [0.5, 1]]. - """ - x = numpyro.sample("x", dist.Normal(0.0, 1.0)) - y = numpyro.sample("y", dist.Normal(0.0, 1.0)) - numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array([0.0])) - return x + y - - - def run_inference(model, args, rng_key): - """ - Run MCMC inference on the given model. 
- - :param model: The model to run inference on - :param args: Command line arguments - :param rng_key: Random number generator key - :return: MCMC object - """ - kernel = MCLMC( - model=model, - diagonal_preconditioning=True, - desired_energy_var=5e-4, - ) - - mcmc = MCMC( - kernel, - num_warmup=1000, - num_samples=1000, - num_chains=1, - progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, - ) - - mcmc.run(rng_key) - mcmc.print_summary(exclude_deterministic=False) - - samples = mcmc.get_samples() - plt.figure(figsize=(8, 8)) - plt.scatter(samples['x'], samples['y'], alpha=0.5) - plt.xlabel('x') - plt.ylabel('y') - plt.title('MCLMC samples from 2D Gaussian') - plt.grid(True) - plt.savefig('mclmc_samples.png') - plt.close() - - return mcmc - - - rng_key = random.PRNGKey(0) - mcmc = run_inference(gaussian_2d_model, args=None, rng_key=rng_key) - diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py new file mode 100644 index 000000000..36c1371c1 --- /dev/null +++ b/numpyro/infer/mclmc.py @@ -0,0 +1,219 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +import jax + +from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import initialize_model +from numpyro.util import identity + +try: + import blackjax + from blackjax.mcmc.integrators import IntegratorState + from blackjax.util import pytree_size + + _BLACKJAX_AVAILABLE = True +except ImportError: + _BLACKJAX_AVAILABLE = False + blackjax = None + IntegratorState = None + pytree_size = None + +FullState = namedtuple( + "FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"] +) + + +class MCLMC(MCMCKernel): + """ + Microcanonical Langevin Monte Carlo (MCLMC) kernel. + + MCLMC is a gradient-based MCMC algorithm that uses Hamiltonian dynamics + on an extended state space. It requires the `blackjax` package. + + **References:** + + 1. 
*Microcanonical Hamiltonian Monte Carlo*, + Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak + https://arxiv.org/abs/2212.08549 + + .. note:: The model must have at least 2 latent dimensions for MCLMC to work + (this is a limitation of the blackjax implementation). + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + :param float desired_energy_var: Target energy variance for step size and + trajectory length tuning. Smaller values lead to more conservative + step sizes. Defaults to 5e-4. + :param bool diagonal_preconditioning: Whether to use diagonal preconditioning + for the mass matrix. Defaults to True. + """ + + def __init__( + self, + model=None, + desired_energy_var=5e-4, + diagonal_preconditioning=True, + ): + if not _BLACKJAX_AVAILABLE: + raise ImportError( + "MCLMC requires the 'blackjax' package. " + "Please install it with: pip install blackjax" + ) + if model is None: + raise ValueError("Model must be specified for MCLMC") + self._model = model + self._diagonal_preconditioning = diagonal_preconditioning + self._desired_energy_var = desired_energy_var + self._init_fn = None + self._sample_fn = None + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "position" + + @property + def default_fields(self): + return (self.sample_field,) + + def get_diagnostics_str(self, state): + """ + Return a diagnostics string for the progress bar. + """ + return "step_size={:.2e}, L={:.2e}".format( + self.adapt_state.step_size, self.adapt_state.L + ) + + def postprocess_fn(self, args, kwargs): + """ + Get a function that transforms unconstrained values at sample sites to values + constrained to the site's support, in addition to returning deterministic + sites in the model. + + :param args: Arguments to the model. + :param kwargs: Keyword arguments to the model. 
+ """ + if self._postprocess_fn is None: + return identity + return self._postprocess_fn(*args, **kwargs) + + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + """ + Initialize the MCLMC kernel. + + :param rng_key: Random number generator key + :param num_warmup: Number of warmup steps + :param init_params: Initial parameters + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Initial state + """ + + init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split( + rng_key, 4 + ) + + init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( + init_model_key, + self._model, + model_args=model_args, + model_kwargs=model_kwargs, + dynamic_args=True, + ) + self._postprocess_fn = postprocess_fn + + def logdensity_fn(position): + return -potential_fn_gen(*model_args, **model_kwargs)(position) + + initial_position = init_params.z + self.logdensity_fn = logdensity_fn + + sampler_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=self.logdensity_fn, + rng_key=init_state_key, + ) + + def kernel(inverse_mass_matrix): + return blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=inverse_mass_matrix, + ) + + self.dim = pytree_size(initial_position) + + # num_steps is a dummy param here (used for tuning fractions) + num_tuning_steps = 100 + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + _, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_tuning_steps, + state=sampler_state, + rng_key=rng_key_tune, + diagonal_preconditioning=self._diagonal_preconditioning, + frac_tune3=num_warmup / (3 * num_tuning_steps), + frac_tune2=num_warmup / (3 * num_tuning_steps), + frac_tune1=num_warmup / (3 * num_tuning_steps), + desired_energy_var=self._desired_energy_var, + ) + + self.adapt_state = blackjax_mclmc_sampler_params + + return 
FullState( + blackjax_state_after_tuning.position, + blackjax_state_after_tuning.momentum, + blackjax_state_after_tuning.logdensity, + blackjax_state_after_tuning.logdensity_grad, + run_key, + ) + + def sample(self, state, model_args, model_kwargs): + """ + Run MCLMC from the given state and return the resulting state. + + :param state: Current state + :param model_args: Model arguments + :param model_kwargs: Model keyword arguments + :return: Next state after running MCLMC + """ + + mclmc_state = IntegratorState( + state.position, state.momentum, state.logdensity, state.logdensity_grad + ) + rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=self.logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, + ) + + new_state, info = kernel( + rng_key=rng_key_sample, + state=mclmc_state, + step_size=self.adapt_state.step_size, + L=self.adapt_state.L, + ) + + return FullState( + new_state.position, + new_state.momentum, + new_state.logdensity, + new_state.logdensity_grad, + rng_key, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["_postprocess_fn"] = None + return state diff --git a/test/infer/test_mclmc.py b/test/infer/test_mclmc.py new file mode 100644 index 000000000..b3fd1af9d --- /dev/null +++ b/test/infer/test_mclmc.py @@ -0,0 +1,155 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +from numpy.testing import assert_allclose +import pytest + +from jax import random +import jax.numpy as jnp + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC +from numpyro.infer.mclmc import MCLMC + + +def test_mclmc_model_required(): + """Test that ValueError is raised when model is None.""" + with pytest.raises(ValueError, match="Model must be specified"): + MCLMC(model=None) + + +def test_mclmc_blackjax_not_installed(monkeypatch): + """Test that ImportError is raised with informative message when blackjax is not installed.""" + import numpyro.infer.mclmc as mclmc_module + + # Temporarily set _BLACKJAX_AVAILABLE to False + monkeypatch.setattr(mclmc_module, "_BLACKJAX_AVAILABLE", False) + + def dummy_model(): + numpyro.sample("x", dist.Normal(0, 1)) + + with pytest.raises(ImportError, match="MCLMC requires the 'blackjax' package"): + MCLMC(model=dummy_model) + + +def test_mclmc_normal(): + """Test MCLMC with a 2D normal distribution. + + Note: MCLMC requires at least 2 dimensions (blackjax limitation). 
+ """ + true_mean = jnp.array([1.0, 2.0]) + true_std = jnp.array([0.5, 1.0]) + num_warmup, num_samples = 1000, 2000 + + def model(): + numpyro.sample("x", dist.Normal(true_mean, true_std).to_event(1)) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert "x" in samples + assert samples["x"].shape == (num_samples, 2) + assert_allclose(jnp.mean(samples["x"], axis=0), true_mean, atol=0.1) + assert_allclose(jnp.std(samples["x"], axis=0), true_std, atol=0.2) + + +def test_mclmc_gaussian_2d(): + """Test MCLMC with a 2D Gaussian model with observation.""" + num_warmup, num_samples = 1000, 1000 + + def model(): + x = numpyro.sample("x", dist.Normal(0.0, 1.0)) + y = numpyro.sample("y", dist.Normal(0.0, 1.0)) + numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array(0.0)) + + kernel = MCLMC( + model=model, + diagonal_preconditioning=True, + desired_energy_var=5e-4, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert "x" in samples + assert "y" in samples + assert samples["x"].shape == (num_samples,) + assert samples["y"].shape == (num_samples,) + # With obs=0, x+y should be close to 0, so means should be near 0 + assert_allclose(jnp.mean(samples["x"]) + jnp.mean(samples["y"]), 0.0, atol=0.2) + + +def test_mclmc_logistic_regression(): + """Test MCLMC with a logistic regression model. + + Note: MCLMC currently doesn't pass model_args, so we use a closure pattern. 
+ """ + N, dim = 1000, 3 + num_warmup, num_samples = 1000, 2000 + + key1, key2, key3 = random.split(random.PRNGKey(0), 3) + data = random.normal(key1, (N, dim)) + true_coefs = jnp.arange(1.0, dim + 1.0) + logits = jnp.sum(true_coefs * data, axis=-1) + labels = dist.Bernoulli(logits=logits).sample(key2) + + # Use closure pattern since MCLMC doesn't pass model_args + def model(): + coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + logits = jnp.sum(coefs * data, axis=-1) + numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(key3) + samples = mcmc.get_samples() + + assert "coefs" in samples + assert samples["coefs"].shape == (num_samples, dim) + assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.5) + + +def test_mclmc_sample_shape(): + """Test that MCLMC produces samples with expected shapes.""" + num_warmup, num_samples = 500, 500 + + def model(): + numpyro.sample("a", dist.Normal(0, 1)) + numpyro.sample("b", dist.Normal(0, 1).expand([3])) + numpyro.sample("c", dist.Normal(0, 1).expand([2, 4])) + + kernel = MCLMC(model=model) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(0)) + samples = mcmc.get_samples() + + assert samples["a"].shape == (num_samples,) + assert samples["b"].shape == (num_samples, 3) + assert samples["c"].shape == (num_samples, 2, 4) From 59d7f417fbe1cf62a7a178475a9a5e454e4d8022 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Tue, 20 Jan 2026 23:36:18 +0100 Subject: [PATCH 03/10] coauthor Co-authored-by: reubenharry From af9f00101eeeec13803eec95aa324d4c9d6fd66f Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Wed, 21 Jan 2026 00:11:49 +0100 Subject: [PATCH 04/10] add blcakjax to test --- setup.py | 1 + 1 file changed, 1 insertion(+) diff 
--git a/setup.py b/setup.py index f0bef9f5a..972a9dad1 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ "scikit-learn", "scipy>=1.9", "ty>=0.0.4", + "blackjax>=1.3", ], "dev": [ "dm-haiku>=0.0.14", From a7280f6c7b07904cea460775640dbea603fd42e8 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 27 Feb 2026 22:52:09 +0100 Subject: [PATCH 05/10] empty From f336db0f6aedbf3847a5224370d27131d9a17a45 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 27 Feb 2026 23:37:08 +0100 Subject: [PATCH 06/10] init no blackjax --- numpyro/infer/mclmc.py | 606 ++++++++++++++++++++++++++++++++------- test/infer/test_mclmc.py | 93 ++++-- 2 files changed, 573 insertions(+), 126 deletions(-) diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py index 36c1371c1..1b0929852 100644 --- a/numpyro/infer/mclmc.py +++ b/numpyro/infer/mclmc.py @@ -4,34 +4,491 @@ from collections import namedtuple import jax +from jax.flatten_util import ravel_pytree +import jax.numpy as jnp +from numpyro.diagnostics import effective_sample_size from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import initialize_model from numpyro.util import identity -try: - import blackjax - from blackjax.mcmc.integrators import IntegratorState - from blackjax.util import pytree_size - - _BLACKJAX_AVAILABLE = True -except ImportError: - _BLACKJAX_AVAILABLE = False - blackjax = None - IntegratorState = None - pytree_size = None - +MCLMCState = namedtuple( + "MCLMCState", ["position", "momentum", "logdensity", "logdensity_grad"] +) +MCLMCInfo = namedtuple("MCLMCInfo", ["logdensity", "kinetic_change", "energy_change"]) +MCLMCAdaptationState = namedtuple( + "MCLMCAdaptationState", ["L", "step_size", "inverse_mass_matrix"] +) FullState = namedtuple( "FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"] ) +# First momentum-stage coefficient in the 5-stage McLachlan splitting scheme. 
+_MCLACHLAN_B1 = 0.1931833275037836 +# Palindromic integrator coefficients for one isokinetic McLachlan update. +_MCLACHLAN_COEFS = (_MCLACHLAN_B1, 0.5, 1 - 2 * _MCLACHLAN_B1, 0.5, _MCLACHLAN_B1) +# When NaNs are detected during adaptation, shrink step size by this factor. +_DELTA_NAN_STEP_SIZE_FACTOR = 0.8 + + +def _pytree_size(pytree): + return sum(jnp.size(leaf) for leaf in jax.tree.leaves(pytree)) + + +def _generate_unit_vector(rng_key, position): + flat_position, unravel_fn = ravel_pytree(position) + sample = jax.random.normal( + rng_key, shape=flat_position.shape, dtype=flat_position.dtype + ) + return unravel_fn(sample / jnp.linalg.norm(sample)) + + +def _incremental_value_update( + expectation, incremental_val, weight=1.0, zero_prevention=0.0 +): + total, average = incremental_val + average = jax.tree.map( + lambda exp, av: jnp.where( + (total * av + weight * exp) == 0.0, + 0.0, + (total * av + weight * exp) / (total + weight + zero_prevention), + ), + expectation, + average, + ) + return total + weight, average + + +def _init_mclmc(position, logdensity_fn, rng_key): + if _pytree_size(position) < 2: + raise ValueError( + "The target distribution must have more than 1 dimension for MCLMC." 
+ ) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return MCLMCState( + position=position, + momentum=_generate_unit_vector(rng_key, position), + logdensity=logdensity, + logdensity_grad=logdensity_grad, + ) + + +def _position_update(position, kinetic_grad, step_size, coef, logdensity_fn): + new_position = jax.tree.map( + lambda x, grad: x + step_size * coef * grad, + position, + kinetic_grad, + ) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(new_position) + return new_position, logdensity, logdensity_grad + + +def _normalized_flatten(x, tol=1e-13): + norm = jnp.linalg.norm(x) + return jnp.where(norm > tol, x / norm, x), norm + + +def _esh_dynamics_momentum_update_one_step( + momentum, + logdensity_grad, + step_size, + coef, + inverse_mass_matrix, + previous_kinetic_energy_change=None, +): + sqrt_inverse_mass_matrix = jnp.sqrt(inverse_mass_matrix) + flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix + flatten_momentum, _ = ravel_pytree(momentum) + dims = flatten_momentum.shape[0] + + normalized_gradient, gradient_norm = _normalized_flatten(flatten_grads) + momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) + delta = step_size * coef * gradient_norm / (dims - 1) + zeta = jnp.exp(-delta) + new_momentum_raw = ( + normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + + 2 * zeta * flatten_momentum + ) + new_momentum_normalized, _ = _normalized_flatten(new_momentum_raw) + next_momentum = unravel_fn(new_momentum_normalized) + kinetic_grad = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) + kinetic_energy_change = ( + delta + - jnp.log(2.0) + + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) + ) * (dims - 1) + if previous_kinetic_energy_change is not None: + kinetic_energy_change = kinetic_energy_change + previous_kinetic_energy_change + return next_momentum, kinetic_grad, kinetic_energy_change + + 
+def _isokinetic_mclachlan_step(state, step_size, logdensity_fn, inverse_mass_matrix): + position, momentum, _, logdensity_grad = state + kinetic_energy_change = None + + for i, coef in enumerate(_MCLACHLAN_COEFS[:-1]): + if i % 2 == 0: + momentum, kinetic_grad, kinetic_energy_change = ( + _esh_dynamics_momentum_update_one_step( + momentum=momentum, + logdensity_grad=logdensity_grad, + step_size=step_size, + coef=coef, + inverse_mass_matrix=inverse_mass_matrix, + previous_kinetic_energy_change=kinetic_energy_change, + ) + ) + else: + position, logdensity, logdensity_grad = _position_update( + position=position, + kinetic_grad=kinetic_grad, + step_size=step_size, + coef=coef, + logdensity_fn=logdensity_fn, + ) + + momentum, _, kinetic_energy_change = _esh_dynamics_momentum_update_one_step( + momentum=momentum, + logdensity_grad=logdensity_grad, + step_size=step_size, + coef=_MCLACHLAN_COEFS[-1], + inverse_mass_matrix=inverse_mass_matrix, + previous_kinetic_energy_change=kinetic_energy_change, + ) + return MCLMCState( + position, momentum, logdensity, logdensity_grad + ), kinetic_energy_change + + +def _partially_refresh_momentum(momentum, rng_key, step_size, L): + flat_momentum, unravel_fn = ravel_pytree(momentum) + dim = flat_momentum.shape[0] + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + z = nu * jax.random.normal( + rng_key, shape=flat_momentum.shape, dtype=flat_momentum.dtype + ) + new_momentum = unravel_fn((flat_momentum + z) / jnp.linalg.norm(flat_momentum + z)) + return jax.lax.cond( + jnp.isinf(L), lambda _: momentum, lambda _: new_momentum, operand=None + ) + + +def _maruyama_step( + init_state, step_size, L, rng_key, logdensity_fn, inverse_mass_matrix +): + key1, key2 = jax.random.split(rng_key) + state = init_state._replace( + momentum=_partially_refresh_momentum( + momentum=init_state.momentum, + rng_key=key1, + L=L, + step_size=step_size * 0.5, + ) + ) + state, kinetic_change = _isokinetic_mclachlan_step( + state=state, + 
step_size=step_size, + logdensity_fn=logdensity_fn, + inverse_mass_matrix=inverse_mass_matrix, + ) + state = state._replace( + momentum=_partially_refresh_momentum( + momentum=state.momentum, + rng_key=key2, + L=L, + step_size=step_size * 0.5, + ) + ) + return state, kinetic_change + + +def _handle_nans(previous_state, next_state, info, key): + new_momentum = _generate_unit_vector(key, previous_state.position) + flat_position, _ = ravel_pytree(next_state.position) + flat_momentum, _ = ravel_pytree(next_state.momentum) + nonans = jnp.logical_and( + jnp.all(jnp.isfinite(flat_position)), jnp.all(jnp.isfinite(flat_momentum)) + ) + state, info = jax.lax.cond( + nonans, + lambda: (next_state, info), + lambda: ( + previous_state._replace(momentum=new_momentum), + MCLMCInfo( + logdensity=previous_state.logdensity, + energy_change=jnp.zeros_like(info.energy_change), + kinetic_change=jnp.zeros_like(info.kinetic_change), + ), + ), + ) + return state, info + + +def _build_kernel(logdensity_fn, inverse_mass_matrix): + def kernel(rng_key, state, L, step_size): + kernel_key, nan_key = jax.random.split(rng_key) + next_state, kinetic_change = _maruyama_step( + init_state=state, + step_size=step_size, + L=L, + rng_key=kernel_key, + logdensity_fn=logdensity_fn, + inverse_mass_matrix=inverse_mass_matrix, + ) + energy_change = kinetic_change - next_state.logdensity + state.logdensity + next_state, info = _handle_nans( + previous_state=state, + next_state=next_state, + info=MCLMCInfo( + logdensity=next_state.logdensity, + energy_change=energy_change, + kinetic_change=kinetic_change, + ), + key=nan_key, + ) + return next_state, info + + return kernel + + +def _adaptation_handle_nans( + previous_state, next_state, step_size, step_size_max, kinetic_change, key +): + flat_position, _ = ravel_pytree(next_state.position) + flat_momentum, _ = ravel_pytree(next_state.momentum) + nonans = jnp.logical_and( + jnp.all(jnp.isfinite(flat_position)), jnp.all(jnp.isfinite(flat_momentum)) + ) + state, 
step_size, kinetic_change = jax.tree.map( + lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (next_state, step_size_max, kinetic_change), + (previous_state, step_size * _DELTA_NAN_STEP_SIZE_FACTOR, 0.0), + ) + state = jax.lax.cond( + jnp.isnan(next_state.logdensity), + lambda: state._replace( + momentum=_generate_unit_vector(key, previous_state.position) + ), + lambda: state, + ) + return nonans, state, step_size, kinetic_change + + +def _make_l_step_size_adaptation( + kernel_fn, + dim, + frac_tune1, + frac_tune2, + diagonal_preconditioning, + desired_energy_var=1e-3, + trust_in_estimate=1.5, + num_effective_samples=150, +): + decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) + + def predictor(previous_state, params, adaptive_state, rng_key): + time, x_average, step_size_max = adaptive_state + rng_key, nan_key = jax.random.split(rng_key) + next_state, info = kernel_fn(params.inverse_mass_matrix)( + rng_key=rng_key, + state=previous_state, + L=params.L, + step_size=params.step_size, + ) + success, state, step_size_max, energy_change = _adaptation_handle_nans( + previous_state=previous_state, + next_state=next_state, + step_size=params.step_size, + step_size_max=step_size_max, + kinetic_change=info.energy_change, + key=nan_key, + ) + xi = jnp.square(energy_change) / (dim * desired_energy_var) + 1e-8 + weight = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate))) + x_average = decay_rate * x_average + weight * ( + xi / jnp.power(params.step_size, 6.0) + ) + time = decay_rate * time + weight + step_size = jnp.power(x_average / time, -1.0 / 6.0) + step_size = jnp.where(step_size < step_size_max, step_size, step_size_max) + params_new = params._replace(step_size=step_size) + return state, params_new, (time, x_average, step_size_max), success + + def step(iteration_state, weight_and_key): + mask, rng_key = weight_and_key + state, params, adaptive_state, streaming_avg = iteration_state + state, params, 
adaptive_state, success = predictor( + state, params, adaptive_state, rng_key + ) + x = ravel_pytree(state.position)[0] + streaming_avg = _incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=streaming_avg, + weight=mask * success * params.step_size, + ) + return (state, params, adaptive_state, streaming_avg), None + + def run_steps(xs, state, params): + return jax.lax.scan( + step, + init=( + state, + params, + (0.0, 0.0, jnp.inf), + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=xs, + )[0] + + def adaptation(state, params, num_steps, rng_key): + num_steps1 = round(num_steps * frac_tune1) + num_steps2 = round(num_steps * frac_tune2) + keys = jax.random.split(rng_key, num_steps1 + num_steps2 + 1) + tune_keys, final_key = keys[:-1], keys[-1] + mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + state, params, _, (_, average) = run_steps((mask, tune_keys), state, params) + L = params.L + inverse_mass_matrix = params.inverse_mass_matrix + if num_steps2 > 1: + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) + L = jnp.sqrt(jnp.sum(variances)) + if diagonal_preconditioning: + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) + L = jnp.sqrt(dim) + steps = round(num_steps2 / 3) + keys = jax.random.split(final_key, steps) + state, params, _, (_, _) = run_steps((jnp.ones(steps), keys), state, params) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) + + return adaptation + + +def _make_adaptation_l(kernel, frac, lfactor): + def adaptation_l(state, params, num_steps, rng_key): + num_steps3 = round(num_steps * frac) + keys = jax.random.split(rng_key, num_steps3) + + def step(curr_state, key): + next_state, _ = kernel( + rng_key=key, + state=curr_state, + L=params.L, + step_size=params.step_size, + ) + return next_state, next_state.position + + state, samples = 
jax.lax.scan(step, init=state, xs=keys) + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + ess = effective_sample_size(flat_samples[None, ...]) + return state, params._replace( + L=lfactor * params.step_size * jnp.mean(num_steps3 / ess) + ) + + return adaptation_l + + +def _mclmc_find_l_and_step_size( + mclmc_kernel, + num_steps, + state, + rng_key, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + desired_energy_var=5e-4, + trust_in_estimate=1.5, + num_effective_samples=150, + diagonal_preconditioning=True, + params=None, + lfactor=0.4, +): + dim = _pytree_size(state.position) + if params is None: + params = MCLMCAdaptationState( + L=jnp.sqrt(dim), + step_size=jnp.sqrt(dim) * 0.25, + inverse_mass_matrix=jnp.ones((dim,)), + ) + + part1_key, part2_key = jax.random.split(rng_key, 2) + num_steps1 = round(num_steps * frac_tune1) + num_steps2 = round(num_steps * frac_tune2) + num_steps2 += diagonal_preconditioning * (num_steps2 // 3) + num_steps3 = round(num_steps * frac_tune3) + + state, params = _make_l_step_size_adaptation( + kernel_fn=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + desired_energy_var=desired_energy_var, + trust_in_estimate=trust_in_estimate, + num_effective_samples=num_effective_samples, + diagonal_preconditioning=diagonal_preconditioning, + )(state, params, num_steps, part1_key) + + total_num_tuning_integrator_steps = num_steps1 + num_steps2 + if num_steps3 >= 2: + state, params = _make_adaptation_l( + kernel=mclmc_kernel(params.inverse_mass_matrix), + frac=frac_tune3, + lfactor=lfactor, + )(state, params, num_steps, part2_key) + total_num_tuning_integrator_steps += num_steps3 + return state, params, total_num_tuning_integrator_steps + class MCLMC(MCMCKernel): """ Microcanonical Langevin Monte Carlo (MCLMC) kernel. - MCLMC is a gradient-based MCMC algorithm that uses Hamiltonian dynamics - on an extended state space. It requires the `blackjax` package. 
+ This kernel implements an isokinetic integrator with stochastic momentum + refreshment. During warmup, it automatically tunes step size, momentum + decoherence length ``L``, and optionally a diagonal preconditioner. + The resulting state can be used with :class:`~numpyro.infer.mcmc.MCMC`. + + Example + ------- + + A minimal 2D model: + + .. code-block:: python + + import jax + import jax.numpy as jnp + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC + from numpyro.infer.mclmc import MCLMC + + def model(): + numpyro.sample("x", dist.Normal(jnp.array([0.0, 0.0]), 1.0).to_event(1)) + + kernel = MCLMC(model=model) + mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False) + mcmc.run(jax.random.key(0)) + samples = mcmc.get_samples() + + Model with observed data and tuned energy variance: + + .. code-block:: python + + def model(X, y=None): + w = numpyro.sample("w", dist.Normal(jnp.zeros(X.shape[-1]), 1.0)) + logits = X @ w + numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=y) + + kernel = MCLMC( + model=model, + desired_energy_var=5e-4, + diagonal_preconditioning=True, + ) + mcmc = MCMC(kernel, num_warmup=1500, num_samples=1000, progress_bar=False) + mcmc.run(jax.random.key(1), X, y) **References:** @@ -39,15 +496,16 @@ class MCLMC(MCMCKernel): Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak https://arxiv.org/abs/2212.08549 - .. note:: The model must have at least 2 latent dimensions for MCLMC to work - (this is a limitation of the blackjax implementation). + .. note:: The model must have at least 2 unconstrained latent dimensions. + This limitation comes from the isokinetic MCLMC dynamics. - :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. - :param float desired_energy_var: Target energy variance for step size and - trajectory length tuning. Smaller values lead to more conservative - step sizes. Defaults to 5e-4. 
- :param bool diagonal_preconditioning: Whether to use diagonal preconditioning - for the mass matrix. Defaults to True. + :param model: Python callable containing NumPyro primitives. + :param float desired_energy_var: Target energy variance used in warmup to tune + step size. Smaller values generally lead to more conservative integration + steps. Defaults to ``5e-4``. + :param bool diagonal_preconditioning: Whether warmup should estimate a diagonal + inverse mass matrix. If ``False``, adaptation uses isotropic scaling. + Defaults to ``True``. """ def __init__( @@ -56,18 +514,11 @@ def __init__( desired_energy_var=5e-4, diagonal_preconditioning=True, ): - if not _BLACKJAX_AVAILABLE: - raise ImportError( - "MCLMC requires the 'blackjax' package. " - "Please install it with: pip install blackjax" - ) if model is None: raise ValueError("Model must be specified for MCLMC") self._model = model self._diagonal_preconditioning = diagonal_preconditioning self._desired_energy_var = desired_energy_var - self._init_fn = None - self._sample_fn = None self._postprocess_fn = None @property @@ -83,42 +534,17 @@ def default_fields(self): return (self.sample_field,) def get_diagnostics_str(self, state): - """ - Return a diagnostics string for the progress bar. - """ return "step_size={:.2e}, L={:.2e}".format( self.adapt_state.step_size, self.adapt_state.L ) def postprocess_fn(self, args, kwargs): - """ - Get a function that transforms unconstrained values at sample sites to values - constrained to the site's support, in addition to returning deterministic - sites in the model. - - :param args: Arguments to the model. - :param kwargs: Keyword arguments to the model. - """ if self._postprocess_fn is None: return identity return self._postprocess_fn(*args, **kwargs) def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): - """ - Initialize the MCLMC kernel. 
- - :param rng_key: Random number generator key - :param num_warmup: Number of warmup steps - :param init_params: Initial parameters - :param model_args: Model arguments - :param model_kwargs: Model keyword arguments - :return: Initial state - """ - - init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split( - rng_key, 4 - ) - + init_model_key, init_state_key, run_key, tune_key = jax.random.split(rng_key, 4) init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( init_model_key, self._model, @@ -131,80 +557,54 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): def logdensity_fn(position): return -potential_fn_gen(*model_args, **model_kwargs)(position) - initial_position = init_params.z self.logdensity_fn = logdensity_fn - - sampler_state = blackjax.mcmc.mclmc.init( - position=initial_position, + sampler_state = _init_mclmc( + position=init_params.z, logdensity_fn=self.logdensity_fn, rng_key=init_state_key, ) - def kernel(inverse_mass_matrix): - return blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + def kernel_fn(inverse_mass_matrix): + return _build_kernel( + logdensity_fn=self.logdensity_fn, inverse_mass_matrix=inverse_mass_matrix, ) - self.dim = pytree_size(initial_position) - - # num_steps is a dummy param here (used for tuning fractions) num_tuning_steps = 100 - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - _, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, + tuned_state, self.adapt_state, _ = _mclmc_find_l_and_step_size( + mclmc_kernel=kernel_fn, num_steps=num_tuning_steps, state=sampler_state, - rng_key=rng_key_tune, + rng_key=tune_key, diagonal_preconditioning=self._diagonal_preconditioning, - frac_tune3=num_warmup / (3 * num_tuning_steps), - frac_tune2=num_warmup / (3 * num_tuning_steps), frac_tune1=num_warmup / (3 * num_tuning_steps), + frac_tune2=num_warmup / (3 * 
num_tuning_steps), + frac_tune3=num_warmup / (3 * num_tuning_steps), desired_energy_var=self._desired_energy_var, ) - - self.adapt_state = blackjax_mclmc_sampler_params - + self._kernel = _build_kernel( + logdensity_fn=self.logdensity_fn, + inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, + ) return FullState( - blackjax_state_after_tuning.position, - blackjax_state_after_tuning.momentum, - blackjax_state_after_tuning.logdensity, - blackjax_state_after_tuning.logdensity_grad, + tuned_state.position, + tuned_state.momentum, + tuned_state.logdensity, + tuned_state.logdensity_grad, run_key, ) def sample(self, state, model_args, model_kwargs): - """ - Run MCLMC from the given state and return the resulting state. - - :param state: Current state - :param model_args: Model arguments - :param model_kwargs: Model keyword arguments - :return: Next state after running MCLMC - """ - - mclmc_state = IntegratorState( + mclmc_state = MCLMCState( state.position, state.momentum, state.logdensity, state.logdensity_grad ) - rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) - - kernel = blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=self.logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, - ) - - new_state, info = kernel( - rng_key=rng_key_sample, + rng_key, sample_key = jax.random.split(state.rng_key, 2) + new_state, _ = self._kernel( + rng_key=sample_key, state=mclmc_state, step_size=self.adapt_state.step_size, L=self.adapt_state.L, ) - return FullState( new_state.position, new_state.momentum, diff --git a/test/infer/test_mclmc.py b/test/infer/test_mclmc.py index b3fd1af9d..f003f9dfc 100644 --- a/test/infer/test_mclmc.py +++ b/test/infer/test_mclmc.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 + from numpy.testing import assert_allclose import pytest @@ -13,31 +14,22 @@ from numpyro.infer.mclmc import MCLMC -def test_mclmc_model_required(): - """Test that ValueError is raised when model is None.""" - with pytest.raises(ValueError, match="Model must be specified"): - MCLMC(model=None) - +def _two_dim_model(): + numpyro.sample("x", dist.Normal(jnp.array([0.0, 0.0]), 1.0).to_event(1)) -def test_mclmc_blackjax_not_installed(monkeypatch): - """Test that ImportError is raised with informative message when blackjax is not installed.""" - import numpyro.infer.mclmc as mclmc_module - # Temporarily set _BLACKJAX_AVAILABLE to False - monkeypatch.setattr(mclmc_module, "_BLACKJAX_AVAILABLE", False) +def _model_with_args(loc, scale=1.0): + numpyro.sample("x", dist.Normal(loc, scale).to_event(1)) - def dummy_model(): - numpyro.sample("x", dist.Normal(0, 1)) - with pytest.raises(ImportError, match="MCLMC requires the 'blackjax' package"): - MCLMC(model=dummy_model) +def test_mclmc_model_required(): + """Test that ValueError is raised when model is None.""" + with pytest.raises(ValueError, match="Model must be specified"): + MCLMC(model=None) def test_mclmc_normal(): - """Test MCLMC with a 2D normal distribution. - - Note: MCLMC requires at least 2 dimensions (blackjax limitation). - """ + """Test MCLMC with a 2D normal distribution.""" true_mean = jnp.array([1.0, 2.0]) true_std = jnp.array([0.5, 1.0]) num_warmup, num_samples = 1000, 2000 @@ -95,10 +87,7 @@ def model(): def test_mclmc_logistic_regression(): - """Test MCLMC with a logistic regression model. - - Note: MCLMC currently doesn't pass model_args, so we use a closure pattern. 
- """ + """Test MCLMC with a logistic regression model.""" N, dim = 1000, 3 num_warmup, num_samples = 1000, 2000 @@ -108,7 +97,7 @@ def test_mclmc_logistic_regression(): logits = jnp.sum(true_coefs * data, axis=-1) labels = dist.Bernoulli(logits=logits).sample(key2) - # Use closure pattern since MCLMC doesn't pass model_args + # Closure pattern is used here for compactness. def model(): coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) logits = jnp.sum(coefs * data, axis=-1) @@ -153,3 +142,61 @@ def model(): assert samples["a"].shape == (num_samples,) assert samples["b"].shape == (num_samples, 3) assert samples["c"].shape == (num_samples, 2, 4) + + +def test_mclmc_model_args_and_kwargs(): + """Test that model_args/model_kwargs are respected during inference.""" + true_mean = jnp.array([1.5, -0.5]) + true_scale = 0.8 + num_warmup, num_samples = 500, 1000 + + kernel = MCLMC(model=_model_with_args) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(1), true_mean, scale=true_scale) + samples = mcmc.get_samples()["x"] + + assert samples.shape == (num_samples, 2) + assert_allclose(jnp.mean(samples, axis=0), true_mean, atol=0.2) + assert_allclose(jnp.std(samples, axis=0), true_scale, atol=0.2) + + +def test_mclmc_rejects_one_dimensional_latent_space(): + """Test that MCLMC rejects models with fewer than 2 latent dimensions.""" + + def one_dim_model(): + numpyro.sample("x", dist.Normal(0.0, 1.0)) + + kernel = MCLMC(model=one_dim_model) + mcmc = MCMC( + kernel, + num_warmup=10, + num_samples=10, + num_chains=1, + progress_bar=False, + ) + with pytest.raises( + ValueError, + match="target distribution must have more than 1 dimension", + ): + mcmc.run(random.PRNGKey(0)) + + +def test_mclmc_small_warmup_runs(): + """Test small warmup edge case where adaptation phases are tiny.""" + kernel = MCLMC(model=_two_dim_model) + mcmc = MCMC( + kernel, + num_warmup=3, 
+ num_samples=20, + num_chains=1, + progress_bar=False, + ) + mcmc.run(random.PRNGKey(2)) + samples = mcmc.get_samples()["x"] + assert samples.shape == (20, 2) From 8d708d2b8bceae52f5721dda2e6db48694c659ed Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 27 Feb 2026 23:52:57 +0100 Subject: [PATCH 07/10] types --- numpyro/_typing.py | 4 + numpyro/infer/mclmc.py | 324 ++++++++++++++++++++++++++++++----------- pyproject.toml | 1 + 3 files changed, 246 insertions(+), 83 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index d53d60c17..4639fda43 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -35,4 +35,8 @@ """A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" +LogDensityFn: TypeAlias = Callable[[PyTree], NumLike] +"""Callable log-density signature used by gradient-based kernels.""" + + NumLikeT = TypeVar("NumLikeT", bound=NumLike) diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py index 1b0929852..d45908f29 100644 --- a/numpyro/infer/mclmc.py +++ b/numpyro/infer/mclmc.py @@ -1,41 +1,67 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple +from collections.abc import Callable +from typing import Any, NamedTuple, cast import jax from jax.flatten_util import ravel_pytree import jax.numpy as jnp +from jax.typing import ArrayLike +from numpyro._typing import LogDensityFn, NumLike, PyTree from numpyro.diagnostics import effective_sample_size from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import initialize_model from numpyro.util import identity -MCLMCState = namedtuple( - "MCLMCState", ["position", "momentum", "logdensity", "logdensity_grad"] -) -MCLMCInfo = namedtuple("MCLMCInfo", ["logdensity", "kinetic_change", "energy_change"]) -MCLMCAdaptationState = namedtuple( - "MCLMCAdaptationState", ["L", "step_size", "inverse_mass_matrix"] -) -FullState = namedtuple( - "FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"] -) + +class MCLMCState(NamedTuple): + position: PyTree + momentum: PyTree + logdensity: NumLike + logdensity_grad: PyTree + + +class MCLMCInfo(NamedTuple): + logdensity: NumLike + kinetic_change: NumLike + energy_change: NumLike + + +class MCLMCAdaptationState(NamedTuple): + L: NumLike + step_size: NumLike + inverse_mass_matrix: ArrayLike + + +class FullState(NamedTuple): + position: PyTree + momentum: PyTree + logdensity: NumLike + logdensity_grad: PyTree + rng_key: jax.dtypes.prng_key + # First momentum-stage coefficient in the 5-stage McLachlan splitting scheme. -_MCLACHLAN_B1 = 0.1931833275037836 +_MCLACHLAN_B1: float = 0.1931833275037836 # Palindromic integrator coefficients for one isokinetic McLachlan update. -_MCLACHLAN_COEFS = (_MCLACHLAN_B1, 0.5, 1 - 2 * _MCLACHLAN_B1, 0.5, _MCLACHLAN_B1) +_MCLACHLAN_COEFS: tuple[float, ...] = ( + _MCLACHLAN_B1, + 0.5, + 1 - 2 * _MCLACHLAN_B1, + 0.5, + _MCLACHLAN_B1, +) # When NaNs are detected during adaptation, shrink step size by this factor. 
-_DELTA_NAN_STEP_SIZE_FACTOR = 0.8 +_DELTA_NAN_STEP_SIZE_FACTOR: float = 0.8 -def _pytree_size(pytree): +def _pytree_size(pytree: PyTree) -> int: return sum(jnp.size(leaf) for leaf in jax.tree.leaves(pytree)) -def _generate_unit_vector(rng_key, position): +def _generate_unit_vector(rng_key: jax.dtypes.prng_key, position: PyTree) -> PyTree: flat_position, unravel_fn = ravel_pytree(position) sample = jax.random.normal( rng_key, shape=flat_position.shape, dtype=flat_position.dtype @@ -44,8 +70,11 @@ def _generate_unit_vector(rng_key, position): def _incremental_value_update( - expectation, incremental_val, weight=1.0, zero_prevention=0.0 -): + expectation: ArrayLike, + incremental_val: tuple[NumLike, ArrayLike], + weight: NumLike = 1.0, + zero_prevention: NumLike = 0.0, +) -> tuple[NumLike, ArrayLike]: total, average = incremental_val average = jax.tree.map( lambda exp, av: jnp.where( @@ -59,7 +88,9 @@ def _incremental_value_update( return total + weight, average -def _init_mclmc(position, logdensity_fn, rng_key): +def _init_mclmc( + position: PyTree, logdensity_fn: LogDensityFn, rng_key: jax.dtypes.prng_key +) -> MCLMCState: if _pytree_size(position) < 2: raise ValueError( "The target distribution must have more than 1 dimension for MCLMC." 
@@ -73,7 +104,13 @@ def _init_mclmc(position, logdensity_fn, rng_key): ) -def _position_update(position, kinetic_grad, step_size, coef, logdensity_fn): +def _position_update( + position: PyTree, + kinetic_grad: PyTree, + step_size: NumLike, + coef: NumLike, + logdensity_fn: LogDensityFn, +) -> tuple[PyTree, NumLike, PyTree]: new_position = jax.tree.map( lambda x, grad: x + step_size * coef * grad, position, @@ -83,19 +120,19 @@ def _position_update(position, kinetic_grad, step_size, coef, logdensity_fn): return new_position, logdensity, logdensity_grad -def _normalized_flatten(x, tol=1e-13): +def _normalized_flatten(x: ArrayLike, tol: float = 1e-13) -> tuple[ArrayLike, NumLike]: norm = jnp.linalg.norm(x) return jnp.where(norm > tol, x / norm, x), norm def _esh_dynamics_momentum_update_one_step( - momentum, - logdensity_grad, - step_size, - coef, - inverse_mass_matrix, - previous_kinetic_energy_change=None, -): + momentum: PyTree, + logdensity_grad: PyTree, + step_size: NumLike, + coef: NumLike, + inverse_mass_matrix: ArrayLike, + previous_kinetic_energy_change: NumLike | None = None, +) -> tuple[PyTree, PyTree, NumLike]: sqrt_inverse_mass_matrix = jnp.sqrt(inverse_mass_matrix) flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix @@ -123,7 +160,12 @@ def _esh_dynamics_momentum_update_one_step( return next_momentum, kinetic_grad, kinetic_energy_change -def _isokinetic_mclachlan_step(state, step_size, logdensity_fn, inverse_mass_matrix): +def _isokinetic_mclachlan_step( + state: MCLMCState, + step_size: NumLike, + logdensity_fn: LogDensityFn, + inverse_mass_matrix: ArrayLike, +) -> tuple[MCLMCState, NumLike]: position, momentum, _, logdensity_grad = state kinetic_energy_change = None @@ -161,7 +203,9 @@ def _isokinetic_mclachlan_step(state, step_size, logdensity_fn, inverse_mass_mat ), kinetic_energy_change -def _partially_refresh_momentum(momentum, rng_key, step_size, L): +def _partially_refresh_momentum( 
+ momentum: PyTree, rng_key: jax.dtypes.prng_key, step_size: NumLike, L: NumLike +) -> PyTree: flat_momentum, unravel_fn = ravel_pytree(momentum) dim = flat_momentum.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) @@ -175,8 +219,13 @@ def _partially_refresh_momentum(momentum, rng_key, step_size, L): def _maruyama_step( - init_state, step_size, L, rng_key, logdensity_fn, inverse_mass_matrix -): + init_state: MCLMCState, + step_size: NumLike, + L: NumLike, + rng_key: jax.dtypes.prng_key, + logdensity_fn: LogDensityFn, + inverse_mass_matrix: ArrayLike, +) -> tuple[MCLMCState, NumLike]: key1, key2 = jax.random.split(rng_key) state = init_state._replace( momentum=_partially_refresh_momentum( @@ -203,7 +252,12 @@ def _maruyama_step( return state, kinetic_change -def _handle_nans(previous_state, next_state, info, key): +def _handle_nans( + previous_state: MCLMCState, + next_state: MCLMCState, + info: MCLMCInfo, + key: jax.dtypes.prng_key, +) -> tuple[MCLMCState, MCLMCInfo]: new_momentum = _generate_unit_vector(key, previous_state.position) flat_position, _ = ravel_pytree(next_state.position) flat_momentum, _ = ravel_pytree(next_state.momentum) @@ -225,8 +279,14 @@ def _handle_nans(previous_state, next_state, info, key): return state, info -def _build_kernel(logdensity_fn, inverse_mass_matrix): - def kernel(rng_key, state, L, step_size): +def _build_kernel( + logdensity_fn: LogDensityFn, inverse_mass_matrix: ArrayLike +) -> Callable[ + [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], tuple[MCLMCState, MCLMCInfo] +]: + def kernel( + rng_key: jax.dtypes.prng_key, state: MCLMCState, L: NumLike, step_size: NumLike + ) -> tuple[MCLMCState, MCLMCInfo]: kernel_key, nan_key = jax.random.split(rng_key) next_state, kinetic_change = _maruyama_step( init_state=state, @@ -253,8 +313,13 @@ def kernel(rng_key, state, L, step_size): def _adaptation_handle_nans( - previous_state, next_state, step_size, step_size_max, kinetic_change, key -): + previous_state: 
MCLMCState, + next_state: MCLMCState, + step_size: NumLike, + step_size_max: NumLike, + kinetic_change: NumLike, + key: jax.dtypes.prng_key, +) -> tuple[NumLike, MCLMCState, NumLike, NumLike]: flat_position, _ = ravel_pytree(next_state.position) flat_momentum, _ = ravel_pytree(next_state.momentum) nonans = jnp.logical_and( @@ -276,18 +341,34 @@ def _adaptation_handle_nans( def _make_l_step_size_adaptation( - kernel_fn, - dim, - frac_tune1, - frac_tune2, - diagonal_preconditioning, - desired_energy_var=1e-3, - trust_in_estimate=1.5, - num_effective_samples=150, -): + kernel_fn: Callable[ + [ArrayLike], + Callable[ + [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], + tuple[MCLMCState, MCLMCInfo], + ], + ], + dim: int, + frac_tune1: NumLike, + frac_tune2: NumLike, + diagonal_preconditioning: bool, + desired_energy_var: NumLike = 1e-3, + trust_in_estimate: NumLike = 1.5, + num_effective_samples: int = 150, +) -> Callable[ + [MCLMCState, MCLMCAdaptationState, int, jax.dtypes.prng_key], + tuple[MCLMCState, MCLMCAdaptationState], +]: decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) - def predictor(previous_state, params, adaptive_state, rng_key): + def predictor( + previous_state: MCLMCState, + params: MCLMCAdaptationState, + adaptive_state: tuple[NumLike, NumLike, NumLike], + rng_key: jax.dtypes.prng_key, + ) -> tuple[ + MCLMCState, MCLMCAdaptationState, tuple[NumLike, NumLike, NumLike], NumLike + ]: time, x_average, step_size_max = adaptive_state rng_key, nan_key = jax.random.split(rng_key) next_state, info = kernel_fn(params.inverse_mass_matrix)( @@ -315,7 +396,23 @@ def predictor(previous_state, params, adaptive_state, rng_key): params_new = params._replace(step_size=step_size) return state, params_new, (time, x_average, step_size_max), success - def step(iteration_state, weight_and_key): + def step( + iteration_state: tuple[ + MCLMCState, + MCLMCAdaptationState, + tuple[NumLike, NumLike, NumLike], + tuple[NumLike, jax.dtypes.prng_key], + 
], + weight_and_key: tuple[NumLike, jax.dtypes.prng_key], + ) -> tuple[ + tuple[ + MCLMCState, + MCLMCAdaptationState, + tuple[NumLike, NumLike, NumLike], + tuple[NumLike, jax.dtypes.prng_key], + ], + None, + ]: mask, rng_key = weight_and_key state, params, adaptive_state, streaming_avg = iteration_state state, params, adaptive_state, success = predictor( @@ -329,7 +426,16 @@ def step(iteration_state, weight_and_key): ) return (state, params, adaptive_state, streaming_avg), None - def run_steps(xs, state, params): + def run_steps( + xs: tuple[ArrayLike, jax.dtypes.prng_key], + state: MCLMCState, + params: MCLMCAdaptationState, + ) -> tuple[ + MCLMCState, + MCLMCAdaptationState, + tuple[NumLike, NumLike, NumLike], + tuple[NumLike, jax.dtypes.prng_key], + ]: return jax.lax.scan( step, init=( @@ -341,7 +447,12 @@ def run_steps(xs, state, params): xs=xs, )[0] - def adaptation(state, params, num_steps, rng_key): + def adaptation( + state: MCLMCState, + params: MCLMCAdaptationState, + num_steps: int, + rng_key: jax.dtypes.prng_key, + ) -> tuple[MCLMCState, MCLMCAdaptationState]: num_steps1 = round(num_steps * frac_tune1) num_steps2 = round(num_steps * frac_tune2) keys = jax.random.split(rng_key, num_steps1 + num_steps2 + 1) @@ -368,11 +479,18 @@ def adaptation(state, params, num_steps, rng_key): def _make_adaptation_l(kernel, frac, lfactor): - def adaptation_l(state, params, num_steps, rng_key): + def adaptation_l( + state: MCLMCState, + params: MCLMCAdaptationState, + num_steps: int, + rng_key: jax.dtypes.prng_key, + ) -> tuple[MCLMCState, MCLMCAdaptationState]: num_steps3 = round(num_steps * frac) keys = jax.random.split(rng_key, num_steps3) - def step(curr_state, key): + def step( + curr_state: MCLMCState, key: jax.dtypes.prng_key + ) -> tuple[MCLMCState, PyTree]: next_state, _ = kernel( rng_key=key, state=curr_state, @@ -392,20 +510,26 @@ def step(curr_state, key): def _mclmc_find_l_and_step_size( - mclmc_kernel, - num_steps, - state, - rng_key, - frac_tune1=0.1, - 
frac_tune2=0.1, - frac_tune3=0.1, - desired_energy_var=5e-4, - trust_in_estimate=1.5, - num_effective_samples=150, - diagonal_preconditioning=True, - params=None, - lfactor=0.4, -): + mclmc_kernel: Callable[ + [ArrayLike], + Callable[ + [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], + tuple[MCLMCState, MCLMCInfo], + ], + ], + num_steps: int, + state: MCLMCState, + rng_key: jax.dtypes.prng_key, + frac_tune1: NumLike = 0.1, + frac_tune2: NumLike = 0.1, + frac_tune3: NumLike = 0.1, + desired_energy_var: NumLike = 5e-4, + trust_in_estimate: NumLike = 1.5, + num_effective_samples: int = 150, + diagonal_preconditioning: bool = True, + params: MCLMCAdaptationState | None = None, + lfactor: NumLike = 0.4, +) -> tuple[MCLMCState, MCLMCAdaptationState, int]: dim = _pytree_size(state.position) if params is None: params = MCLMCAdaptationState( @@ -509,41 +633,61 @@ def model(X, y=None): """ def __init__( - self, - model=None, - desired_energy_var=5e-4, - diagonal_preconditioning=True, - ): + self: "MCLMC", + model: Callable[..., Any] | None = None, + desired_energy_var: NumLike = 5e-4, + diagonal_preconditioning: bool = True, + ) -> None: if model is None: raise ValueError("Model must be specified for MCLMC") self._model = model self._diagonal_preconditioning = diagonal_preconditioning self._desired_energy_var = desired_energy_var - self._postprocess_fn = None + self._postprocess_fn: Callable[..., Callable[[PyTree], PyTree]] | None = None + self.logdensity_fn: LogDensityFn | None = None + self.adapt_state: MCLMCAdaptationState | None = None + self._kernel: ( + Callable[ + [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], + tuple[MCLMCState, MCLMCInfo], + ] + | None + ) = None @property - def model(self): + def model(self: "MCLMC") -> Callable[..., Any]: return self._model @property - def sample_field(self): + def sample_field(self: "MCLMC") -> str: return "position" @property - def default_fields(self): + def default_fields(self: "MCLMC") -> tuple[str, ...]: return 
(self.sample_field,) - def get_diagnostics_str(self, state): + def get_diagnostics_str(self: "MCLMC", state: FullState) -> str: + if self.adapt_state is None: + return "" return "step_size={:.2e}, L={:.2e}".format( self.adapt_state.step_size, self.adapt_state.L ) - def postprocess_fn(self, args, kwargs): + def postprocess_fn( + self: "MCLMC", args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Callable[[PyTree], PyTree]: if self._postprocess_fn is None: - return identity + return cast(Callable[[PyTree], PyTree], identity) return self._postprocess_fn(*args, **kwargs) - def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + def init( + self: "MCLMC", + rng_key: jax.dtypes.prng_key, + num_warmup: int, + init_params: Any, + model_args: tuple[Any, ...], + model_kwargs: dict[str, Any], + ) -> FullState: init_model_key, init_state_key, run_key, tune_key = jax.random.split(rng_key, 4) init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( init_model_key, @@ -554,7 +698,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): ) self._postprocess_fn = postprocess_fn - def logdensity_fn(position): + def logdensity_fn(position: PyTree) -> NumLike: return -potential_fn_gen(*model_args, **model_kwargs)(position) self.logdensity_fn = logdensity_fn @@ -564,7 +708,12 @@ def logdensity_fn(position): rng_key=init_state_key, ) - def kernel_fn(inverse_mass_matrix): + def kernel_fn( + inverse_mass_matrix: ArrayLike, + ) -> Callable[ + [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], + tuple[MCLMCState, MCLMCInfo], + ]: return _build_kernel( logdensity_fn=self.logdensity_fn, inverse_mass_matrix=inverse_mass_matrix, @@ -594,11 +743,20 @@ def kernel_fn(inverse_mass_matrix): run_key, ) - def sample(self, state, model_args, model_kwargs): + def sample( + self: "MCLMC", + state: FullState, + model_args: tuple[Any, ...], + model_kwargs: dict[str, Any], + ) -> FullState: + del model_args, model_kwargs mclmc_state = MCLMCState( 
state.position, state.momentum, state.logdensity, state.logdensity_grad
         )
         rng_key, sample_key = jax.random.split(state.rng_key, 2)
+        if self._kernel is None or self.adapt_state is None:
+            msg = "MCLMC must be initialized before calling sample."
+            raise RuntimeError(msg)
         new_state, _ = self._kernel(
             rng_key=sample_key,
             state=mclmc_state,
@@ -613,7 +771,7 @@ def sample(self, state, model_args, model_kwargs):
             rng_key,
         )
 
-    def __getstate__(self):
+    def __getstate__(self: "MCLMC") -> dict[str, Any]:
         state = self.__dict__.copy()
         state["_postprocess_fn"] = None
         return state
diff --git a/pyproject.toml b/pyproject.toml
index 46486b1ac..b3cb374f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,6 +224,7 @@ module = [
     "numpyro.diagnostics.*",
     "numpyro.handlers.*",
     "numpyro.infer.elbo.*",
+    "numpyro.infer.mclmc.*",
     "numpyro.optim.*",
     "numpyro.primitives.*",
     "numpyro.patch.*",
From 621a95c05e66c4ec87f2f16e25ce2f7c22acfa64 Mon Sep 17 00:00:00 2001
From: Juan Orduz
Date: Fri, 27 Feb 2026 23:58:08 +0100
Subject: [PATCH 08/10] docstrings

---
 numpyro/infer/mclmc.py | 148 +++++++++++++++++++++++++++++------------
 1 file changed, 105 insertions(+), 43 deletions(-)

diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py
index d45908f29..806ca7eaf 100644
--- a/numpyro/infer/mclmc.py
+++ b/numpyro/infer/mclmc.py
@@ -43,6 +43,12 @@ class FullState(NamedTuple):
     rng_key: jax.dtypes.prng_key
 
 
+KernelFn = Callable[
+    [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], tuple[MCLMCState, MCLMCInfo]
+]
+KernelFactoryFn = Callable[[ArrayLike], KernelFn]
+
+
 # First momentum-stage coefficient in the 5-stage McLachlan splitting scheme.
 _MCLACHLAN_B1: float = 0.1931833275037836
 # Palindromic integrator coefficients for one isokinetic McLachlan update.
@@ -76,12 +82,14 @@ def _incremental_value_update( zero_prevention: NumLike = 0.0, ) -> tuple[NumLike, ArrayLike]: total, average = incremental_val + total_weight = total + weight + zero_prevention + + def _update_average(exp: ArrayLike, av: ArrayLike) -> ArrayLike: + numerator = total * av + weight * exp + return jnp.where(numerator == 0.0, 0.0, numerator / total_weight) + average = jax.tree.map( - lambda exp, av: jnp.where( - (total * av + weight * exp) == 0.0, - 0.0, - (total * av + weight * exp) / (total + weight + zero_prevention), - ), + _update_average, expectation, average, ) @@ -252,6 +260,14 @@ def _maruyama_step( return state, kinetic_change +def _state_is_finite(state: MCLMCState) -> NumLike: + flat_position, _ = ravel_pytree(state.position) + flat_momentum, _ = ravel_pytree(state.momentum) + return jnp.logical_and( + jnp.all(jnp.isfinite(flat_position)), jnp.all(jnp.isfinite(flat_momentum)) + ) + + def _handle_nans( previous_state: MCLMCState, next_state: MCLMCState, @@ -259,11 +275,7 @@ def _handle_nans( key: jax.dtypes.prng_key, ) -> tuple[MCLMCState, MCLMCInfo]: new_momentum = _generate_unit_vector(key, previous_state.position) - flat_position, _ = ravel_pytree(next_state.position) - flat_momentum, _ = ravel_pytree(next_state.momentum) - nonans = jnp.logical_and( - jnp.all(jnp.isfinite(flat_position)), jnp.all(jnp.isfinite(flat_momentum)) - ) + nonans = _state_is_finite(next_state) state, info = jax.lax.cond( nonans, lambda: (next_state, info), @@ -281,9 +293,7 @@ def _handle_nans( def _build_kernel( logdensity_fn: LogDensityFn, inverse_mass_matrix: ArrayLike -) -> Callable[ - [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], tuple[MCLMCState, MCLMCInfo] -]: +) -> KernelFn: def kernel( rng_key: jax.dtypes.prng_key, state: MCLMCState, L: NumLike, step_size: NumLike ) -> tuple[MCLMCState, MCLMCInfo]: @@ -320,11 +330,7 @@ def _adaptation_handle_nans( kinetic_change: NumLike, key: jax.dtypes.prng_key, ) -> tuple[NumLike, MCLMCState, NumLike, 
NumLike]: - flat_position, _ = ravel_pytree(next_state.position) - flat_momentum, _ = ravel_pytree(next_state.momentum) - nonans = jnp.logical_and( - jnp.all(jnp.isfinite(flat_position)), jnp.all(jnp.isfinite(flat_momentum)) - ) + nonans = _state_is_finite(next_state) state, step_size, kinetic_change = jax.tree.map( lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (next_state, step_size_max, kinetic_change), @@ -341,13 +347,7 @@ def _adaptation_handle_nans( def _make_l_step_size_adaptation( - kernel_fn: Callable[ - [ArrayLike], - Callable[ - [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], - tuple[MCLMCState, MCLMCInfo], - ], - ], + kernel_fn: KernelFactoryFn, dim: int, frac_tune1: NumLike, frac_tune2: NumLike, @@ -510,13 +510,7 @@ def step( def _mclmc_find_l_and_step_size( - mclmc_kernel: Callable[ - [ArrayLike], - Callable[ - [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], - tuple[MCLMCState, MCLMCInfo], - ], - ], + mclmc_kernel: KernelFactoryFn, num_steps: int, state: MCLMCState, rng_key: jax.dtypes.prng_key, @@ -638,6 +632,19 @@ def __init__( desired_energy_var: NumLike = 5e-4, diagonal_preconditioning: bool = True, ) -> None: + """ + Construct an MCLMC kernel. + + :param model: NumPyro model callable that defines latent variables and + observations. + :param desired_energy_var: Target energy variance used during warmup to + tune step size. Smaller values typically produce more conservative + integrator updates. + :param diagonal_preconditioning: Whether to estimate a diagonal inverse + mass matrix during warmup. If ``False``, adaptation uses isotropic + scaling. + :raises ValueError: If ``model`` is not provided. 
+ """ if model is None: raise ValueError("Model must be specified for MCLMC") self._model = model @@ -646,27 +653,41 @@ def __init__( self._postprocess_fn: Callable[..., Callable[[PyTree], PyTree]] | None = None self.logdensity_fn: LogDensityFn | None = None self.adapt_state: MCLMCAdaptationState | None = None - self._kernel: ( - Callable[ - [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], - tuple[MCLMCState, MCLMCInfo], - ] - | None - ) = None + self._kernel: KernelFn | None = None @property def model(self: "MCLMC") -> Callable[..., Any]: + """Return the model callable associated with this kernel.""" return self._model @property def sample_field(self: "MCLMC") -> str: + """ + Name of the state attribute treated as the MCMC sample. + + This is used by :class:`~numpyro.infer.mcmc.MCMC` for collection and + postprocessing. + """ return "position" @property def default_fields(self: "MCLMC") -> tuple[str, ...]: + """ + State attributes collected by default during sampling. + + :return: Tuple of field names to collect from each state. + """ return (self.sample_field,) def get_diagnostics_str(self: "MCLMC", state: FullState) -> str: + """ + Return progress-bar diagnostics for current adaptation parameters. + + :param state: Current full sampler state (unused; present for kernel API + compatibility). + :return: A formatted diagnostics string during/after initialization, or + an empty string if adaptation is unavailable. + """ if self.adapt_state is None: return "" return "step_size={:.2e}, L={:.2e}".format( @@ -676,6 +697,14 @@ def get_diagnostics_str(self: "MCLMC", state: FullState) -> str: def postprocess_fn( self: "MCLMC", args: tuple[Any, ...], kwargs: dict[str, Any] ) -> Callable[[PyTree], PyTree]: + """ + Build a transform from unconstrained latent space to constrained space. + + :param args: Positional model arguments used to initialize transforms. + :param kwargs: Keyword model arguments used to initialize transforms. 
+ :return: Callable that maps unconstrained latent samples to constrained + values and includes deterministic sites. + """ if self._postprocess_fn is None: return cast(Callable[[PyTree], PyTree], identity) return self._postprocess_fn(*args, **kwargs) @@ -688,6 +717,23 @@ def init( model_args: tuple[Any, ...], model_kwargs: dict[str, Any], ) -> FullState: + """ + Initialize sampler state and run warmup adaptation. + + This method initializes model state, builds the log-density function, + adapts ``step_size``, ``L``, and (optionally) diagonal preconditioning, + then returns a ready-to-sample state. + + :param rng_key: JAX PRNG key. + :param num_warmup: Number of warmup steps requested by the outer MCMC + driver; used to set adaptation phase fractions. + :param init_params: Optional initial parameters (kept for kernel API + compatibility; model initialization is delegated to + :func:`~numpyro.infer.util.initialize_model`). + :param model_args: Positional arguments passed to the model. + :param model_kwargs: Keyword arguments passed to the model. + :return: Fully initialized :class:`FullState`. + """ init_model_key, init_state_key, run_key, tune_key = jax.random.split(rng_key, 4) init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( init_model_key, @@ -710,10 +756,7 @@ def logdensity_fn(position: PyTree) -> NumLike: def kernel_fn( inverse_mass_matrix: ArrayLike, - ) -> Callable[ - [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], - tuple[MCLMCState, MCLMCInfo], - ]: + ) -> KernelFn: return _build_kernel( logdensity_fn=self.logdensity_fn, inverse_mass_matrix=inverse_mass_matrix, @@ -749,6 +792,17 @@ def sample( model_args: tuple[Any, ...], model_kwargs: dict[str, Any], ) -> FullState: + """ + Advance the Markov chain by one MCLMC transition. + + :param state: Current full sampler state. + :param model_args: Unused after initialization (kept for API + compatibility with :class:`~numpyro.infer.mcmc.MCMCKernel`). 
+ :param model_kwargs: Unused after initialization (kept for API + compatibility with :class:`~numpyro.infer.mcmc.MCMCKernel`). + :return: Next :class:`FullState` after one transition. + :raises RuntimeError: If called before :meth:`init`. + """ del model_args, model_kwargs mclmc_state = MCLMCState( state.position, state.momentum, state.logdensity, state.logdensity_grad @@ -772,6 +826,14 @@ def sample( ) def __getstate__(self: "MCLMC") -> dict[str, Any]: + """ + Return a pickle-safe object state. + + The cached postprocess closure is intentionally cleared because closures + from ``initialize_model`` are not reliably serializable. + + :return: Serializable state dictionary for this kernel instance. + """ state = self.__dict__.copy() state["_postprocess_fn"] = None return state From d27da771d97f13459f3f8a26d0217f01aee43f65 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Sat, 28 Feb 2026 11:37:19 +0100 Subject: [PATCH 09/10] cleanup --- numpyro/infer/mclmc.py | 100 ++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 40 deletions(-) diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py index 806ca7eaf..9cd0ad340 100644 --- a/numpyro/infer/mclmc.py +++ b/numpyro/infer/mclmc.py @@ -43,6 +43,24 @@ class FullState(NamedTuple): rng_key: jax.dtypes.prng_key +class _AdaptationAverages(NamedTuple): + time: NumLike + x_average: NumLike + step_size_max: NumLike + + +class _StreamingAverage(NamedTuple): + total: NumLike + average: ArrayLike + + +class _AdaptationIterationState(NamedTuple): + state: MCLMCState + params: MCLMCAdaptationState + adaptive_state: _AdaptationAverages + streaming_avg: _StreamingAverage + + KernelFn = Callable[ [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], tuple[MCLMCState, MCLMCInfo] ] @@ -77,10 +95,10 @@ def _generate_unit_vector(rng_key: jax.dtypes.prng_key, position: PyTree) -> PyT def _incremental_value_update( expectation: ArrayLike, - incremental_val: tuple[NumLike, ArrayLike], + incremental_val: 
_StreamingAverage, weight: NumLike = 1.0, zero_prevention: NumLike = 0.0, -) -> tuple[NumLike, ArrayLike]: +) -> _StreamingAverage: total, average = incremental_val total_weight = total + weight + zero_prevention @@ -93,7 +111,7 @@ def _update_average(exp: ArrayLike, av: ArrayLike) -> ArrayLike: expectation, average, ) - return total + weight, average + return _StreamingAverage(total + weight, average) def _init_mclmc( @@ -268,19 +286,26 @@ def _state_is_finite(state: MCLMCState) -> NumLike: ) +def _fallback_state_with_fresh_momentum( + previous_state: MCLMCState, key: jax.dtypes.prng_key +) -> MCLMCState: + return previous_state._replace( + momentum=_generate_unit_vector(key, previous_state.position) + ) + + def _handle_nans( previous_state: MCLMCState, next_state: MCLMCState, info: MCLMCInfo, key: jax.dtypes.prng_key, ) -> tuple[MCLMCState, MCLMCInfo]: - new_momentum = _generate_unit_vector(key, previous_state.position) nonans = _state_is_finite(next_state) state, info = jax.lax.cond( nonans, lambda: (next_state, info), lambda: ( - previous_state._replace(momentum=new_momentum), + _fallback_state_with_fresh_momentum(previous_state, key), MCLMCInfo( logdensity=previous_state.logdensity, energy_change=jnp.zeros_like(info.energy_change), @@ -338,9 +363,7 @@ def _adaptation_handle_nans( ) state = jax.lax.cond( jnp.isnan(next_state.logdensity), - lambda: state._replace( - momentum=_generate_unit_vector(key, previous_state.position) - ), + lambda: _fallback_state_with_fresh_momentum(state, key), lambda: state, ) return nonans, state, step_size, kinetic_change @@ -364,11 +387,9 @@ def _make_l_step_size_adaptation( def predictor( previous_state: MCLMCState, params: MCLMCAdaptationState, - adaptive_state: tuple[NumLike, NumLike, NumLike], + adaptive_state: _AdaptationAverages, rng_key: jax.dtypes.prng_key, - ) -> tuple[ - MCLMCState, MCLMCAdaptationState, tuple[NumLike, NumLike, NumLike], NumLike - ]: + ) -> tuple[MCLMCState, MCLMCAdaptationState, _AdaptationAverages, 
NumLike]: time, x_average, step_size_max = adaptive_state rng_key, nan_key = jax.random.split(rng_key) next_state, info = kernel_fn(params.inverse_mass_matrix)( @@ -394,25 +415,17 @@ def predictor( step_size = jnp.power(x_average / time, -1.0 / 6.0) step_size = jnp.where(step_size < step_size_max, step_size, step_size_max) params_new = params._replace(step_size=step_size) - return state, params_new, (time, x_average, step_size_max), success + return ( + state, + params_new, + _AdaptationAverages(time, x_average, step_size_max), + success, + ) def step( - iteration_state: tuple[ - MCLMCState, - MCLMCAdaptationState, - tuple[NumLike, NumLike, NumLike], - tuple[NumLike, jax.dtypes.prng_key], - ], + iteration_state: _AdaptationIterationState, weight_and_key: tuple[NumLike, jax.dtypes.prng_key], - ) -> tuple[ - tuple[ - MCLMCState, - MCLMCAdaptationState, - tuple[NumLike, NumLike, NumLike], - tuple[NumLike, jax.dtypes.prng_key], - ], - None, - ]: + ) -> tuple[_AdaptationIterationState, None]: mask, rng_key = weight_and_key state, params, adaptive_state, streaming_avg = iteration_state state, params, adaptive_state, success = predictor( @@ -424,25 +437,22 @@ def step( incremental_val=streaming_avg, weight=mask * success * params.step_size, ) - return (state, params, adaptive_state, streaming_avg), None + return _AdaptationIterationState( + state, params, adaptive_state, streaming_avg + ), None def run_steps( xs: tuple[ArrayLike, jax.dtypes.prng_key], state: MCLMCState, params: MCLMCAdaptationState, - ) -> tuple[ - MCLMCState, - MCLMCAdaptationState, - tuple[NumLike, NumLike, NumLike], - tuple[NumLike, jax.dtypes.prng_key], - ]: + ) -> _AdaptationIterationState: return jax.lax.scan( step, - init=( + init=_AdaptationIterationState( state, params, - (0.0, 0.0, jnp.inf), - (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + _AdaptationAverages(0.0, 0.0, jnp.inf), + _StreamingAverage(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), ), xs=xs, )[0] @@ -459,7 +469,9 @@ def 
adaptation( tune_keys, final_key = keys[:-1], keys[-1] mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - state, params, _, (_, average) = run_steps((mask, tune_keys), state, params) + iteration_state = run_steps((mask, tune_keys), state, params) + state, params = iteration_state.state, iteration_state.params + average = iteration_state.streaming_avg.average L = params.L inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: @@ -472,7 +484,8 @@ def adaptation( L = jnp.sqrt(dim) steps = round(num_steps2 / 3) keys = jax.random.split(final_key, steps) - state, params, _, (_, _) = run_steps((jnp.ones(steps), keys), state, params) + iteration_state = run_steps((jnp.ones(steps), keys), state, params) + state, params = iteration_state.state, iteration_state.params return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return adaptation @@ -502,6 +515,13 @@ def step( state, samples = jax.lax.scan(step, init=state, xs=keys) flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) + ess = jnp.nan_to_num( + ess, + nan=1.0, + posinf=float(num_steps3), + neginf=1.0, + ) + ess = jnp.clip(ess, min=1.0) return state, params._replace( L=lfactor * params.step_size * jnp.mean(num_steps3 / ess) ) From e2d754c1e71c8c92833539218105d13e6d576ddb Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Sat, 28 Feb 2026 19:35:38 +0100 Subject: [PATCH 10/10] tests --- test/infer/test_mclmc.py | 457 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 457 insertions(+) diff --git a/test/infer/test_mclmc.py b/test/infer/test_mclmc.py index f003f9dfc..68e231a2d 100644 --- a/test/infer/test_mclmc.py +++ b/test/infer/test_mclmc.py @@ -5,12 +5,14 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import random import jax.numpy as jnp import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC +import numpyro.infer.mclmc as mclmc_module 
from numpyro.infer.mclmc import MCLMC @@ -22,6 +24,348 @@ def _model_with_args(loc, scale=1.0): numpyro.sample("x", dist.Normal(loc, scale).to_event(1)) +def _gaussian_logdensity(x): + return -0.5 * jnp.sum(jnp.square(x)) + + +def _make_test_state(key=None): + if key is None: + key = random.PRNGKey(0) + return mclmc_module._init_mclmc( + position=jnp.array([0.3, -0.7]), + logdensity_fn=_gaussian_logdensity, + rng_key=key, + ) + + +def test_pytree_size_counts_all_leaves(): + pytree = {"a": jnp.zeros((2, 3)), "b": [jnp.ones((4,)), jnp.ones(())]} + assert mclmc_module._pytree_size(pytree) == 11 + + +def test_generate_unit_vector_has_unit_norm(): + position = jnp.array([1.0, 2.0, 3.0]) + vec = mclmc_module._generate_unit_vector(random.PRNGKey(0), position) + flat_vec, _ = jax.flatten_util.ravel_pytree(vec) + assert_allclose(jnp.linalg.norm(flat_vec), 1.0, atol=1e-6) + assert flat_vec.shape == position.shape + + +def test_incremental_value_update_weighted_average(): + avg = mclmc_module._StreamingAverage( + total=jnp.array(0.0), average=jnp.array([0.0, 0.0]) + ) + avg = mclmc_module._incremental_value_update( + expectation=jnp.array([2.0, 4.0]), incremental_val=avg, weight=2.0 + ) + avg = mclmc_module._incremental_value_update( + expectation=jnp.array([4.0, 8.0]), incremental_val=avg, weight=2.0 + ) + assert_allclose(avg.average, jnp.array([3.0, 6.0]), atol=1e-6) + assert_allclose(avg.total, 4.0, atol=1e-6) + + +def test_incremental_value_update_zero_numerator_safe(): + avg = mclmc_module._StreamingAverage(total=jnp.array(0.0), average=jnp.array([0.0])) + updated = mclmc_module._incremental_value_update( + expectation=jnp.array([0.0]), incremental_val=avg, weight=0.0 + ) + assert_allclose(updated.average, jnp.array([0.0])) + + +def test_init_mclmc_rejects_low_dimension(): + with pytest.raises( + ValueError, match="target distribution must have more than 1 dimension" + ): + mclmc_module._init_mclmc( + position=jnp.array([0.0]), + logdensity_fn=lambda x: -0.5 * 
jnp.sum(x**2), + rng_key=random.PRNGKey(0), + ) + + +def test_init_mclmc_returns_valid_state(): + state = _make_test_state(random.PRNGKey(1)) + flat_momentum, _ = jax.flatten_util.ravel_pytree(state.momentum) + assert jnp.isfinite(state.logdensity) + assert jnp.all(jnp.isfinite(state.logdensity_grad)) + assert_allclose(jnp.linalg.norm(flat_momentum), 1.0, atol=1e-6) + + +def test_position_update_matches_expected_gaussian_update(): + position = jnp.array([1.0, 2.0]) + kinetic_grad = jnp.array([0.5, -1.0]) + new_position, logdensity, grad = mclmc_module._position_update( + position=position, + kinetic_grad=kinetic_grad, + step_size=0.1, + coef=0.5, + logdensity_fn=_gaussian_logdensity, + ) + expected_position = jnp.array([1.025, 1.95]) + assert_allclose(new_position, expected_position, atol=1e-7) + assert_allclose(logdensity, _gaussian_logdensity(expected_position), atol=1e-7) + assert_allclose(grad, -expected_position, atol=1e-7) + + +def test_normalized_flatten_for_nonzero_and_zero_vectors(): + normalized, norm = mclmc_module._normalized_flatten(jnp.array([3.0, 4.0])) + assert_allclose(normalized, jnp.array([0.6, 0.8]), atol=1e-7) + assert_allclose(norm, 5.0, atol=1e-7) + + normalized_zero, norm_zero = mclmc_module._normalized_flatten(jnp.zeros((3,))) + assert_allclose(normalized_zero, jnp.zeros((3,)), atol=1e-7) + assert_allclose(norm_zero, 0.0, atol=1e-7) + + +def test_esh_dynamics_momentum_update_matches_naive_formula(): + step_size = 1e-3 + key0, key1 = random.split(random.PRNGKey(62)) + gradient = random.uniform(key0, shape=(3,)) + momentum = random.uniform(key1, shape=(3,)) + momentum = momentum / jnp.linalg.norm(momentum) + + gradient_norm = jnp.linalg.norm(gradient) + gradient_normalized = gradient / gradient_norm + delta = step_size * gradient_norm / (momentum.shape[0] - 1) + naive_next = ( + momentum + + gradient_normalized + * ( + jnp.sinh(delta) + + jnp.dot(gradient_normalized, momentum * (jnp.cosh(delta) - 1)) + ) + ) / (jnp.cosh(delta) + 
jnp.dot(gradient_normalized, momentum * jnp.sinh(delta))) + naive_next = naive_next / jnp.linalg.norm(naive_next) + + next_momentum, _, _ = mclmc_module._esh_dynamics_momentum_update_one_step( + momentum=momentum, + logdensity_grad=gradient, + step_size=step_size, + coef=1.0, + inverse_mass_matrix=jnp.ones((3,)), + ) + assert_allclose(next_momentum, naive_next, atol=1e-6) + + +def test_isokinetic_mclachlan_step_returns_finite_state_and_unit_momentum(): + state = _make_test_state(random.PRNGKey(0)) + next_state, kinetic_change = mclmc_module._isokinetic_mclachlan_step( + state=state, + step_size=1e-3, + logdensity_fn=_gaussian_logdensity, + inverse_mass_matrix=jnp.ones((2,)), + ) + flat_momentum, _ = jax.flatten_util.ravel_pytree(next_state.momentum) + assert jnp.isfinite(kinetic_change) + assert jnp.isfinite(next_state.logdensity) + assert jnp.all(jnp.isfinite(next_state.logdensity_grad)) + assert_allclose(jnp.linalg.norm(flat_momentum), 1.0, atol=1e-5) + + +def test_partially_refresh_momentum_respects_infinite_l(): + momentum = jnp.array([1.0, 0.0]) + refreshed_inf = mclmc_module._partially_refresh_momentum( + momentum=momentum, + rng_key=random.PRNGKey(0), + step_size=0.1, + L=jnp.inf, + ) + assert_allclose(refreshed_inf, momentum) + + refreshed = mclmc_module._partially_refresh_momentum( + momentum=momentum, + rng_key=random.PRNGKey(0), + step_size=0.1, + L=1.0, + ) + assert_allclose(jnp.linalg.norm(refreshed), 1.0, atol=1e-6) + + +def test_maruyama_step_returns_finite_values(): + state = _make_test_state(random.PRNGKey(0)) + next_state, kinetic_change = mclmc_module._maruyama_step( + init_state=state, + step_size=1e-2, + L=1.0, + rng_key=random.PRNGKey(1), + logdensity_fn=_gaussian_logdensity, + inverse_mass_matrix=jnp.ones((2,)), + ) + assert jnp.isfinite(kinetic_change) + assert jnp.isfinite(next_state.logdensity) + assert jnp.all(jnp.isfinite(next_state.logdensity_grad)) + assert mclmc_module._state_is_finite(next_state) + + +def 
test_state_is_finite_detects_nan_and_inf(): + state = _make_test_state(random.PRNGKey(0)) + assert mclmc_module._state_is_finite(state) + + nan_state = state._replace(position=jnp.array([jnp.nan, 0.0])) + inf_state = state._replace(momentum=jnp.array([jnp.inf, 0.0])) + assert not mclmc_module._state_is_finite(nan_state) + assert not mclmc_module._state_is_finite(inf_state) + + +def test_fallback_state_with_fresh_momentum_preserves_position_and_logdensity(): + state = _make_test_state(random.PRNGKey(0)) + new_state = mclmc_module._fallback_state_with_fresh_momentum( + previous_state=state, key=random.PRNGKey(2) + ) + assert_allclose(new_state.position, state.position) + assert_allclose(new_state.logdensity, state.logdensity) + assert_allclose(jnp.linalg.norm(new_state.momentum), 1.0, atol=1e-6) + + +def test_handle_nans_keeps_valid_state_and_falls_back_for_invalid_state(): + previous = _make_test_state(random.PRNGKey(0)) + valid_next = previous._replace(position=previous.position + 0.1) + info = mclmc_module.MCLMCInfo( + logdensity=valid_next.logdensity, + kinetic_change=jnp.array(0.3), + energy_change=jnp.array(0.2), + ) + state_ok, info_ok = mclmc_module._handle_nans( + previous_state=previous, next_state=valid_next, info=info, key=random.PRNGKey(1) + ) + assert_allclose(state_ok.position, valid_next.position) + assert_allclose(info_ok.energy_change, info.energy_change) + + invalid_next = valid_next._replace(position=jnp.array([jnp.nan, 0.0])) + state_bad, info_bad = mclmc_module._handle_nans( + previous_state=previous, + next_state=invalid_next, + info=info, + key=random.PRNGKey(3), + ) + assert_allclose(state_bad.position, previous.position) + assert_allclose(info_bad.logdensity, previous.logdensity) + assert_allclose(info_bad.energy_change, 0.0) + assert_allclose(info_bad.kinetic_change, 0.0) + + +def test_build_kernel_single_step_outputs_finite_state_and_info(): + kernel = mclmc_module._build_kernel( + logdensity_fn=_gaussian_logdensity, 
inverse_mass_matrix=jnp.ones((2,)) + ) + state = _make_test_state(random.PRNGKey(0)) + next_state, info = kernel( + rng_key=random.PRNGKey(1), state=state, L=1.0, step_size=1e-2 + ) + assert mclmc_module._state_is_finite(next_state) + assert jnp.isfinite(info.logdensity) + assert jnp.isfinite(info.energy_change) + assert jnp.isfinite(info.kinetic_change) + + +def test_adaptation_handle_nans_behavior(): + previous = _make_test_state(random.PRNGKey(0)) + next_state = previous._replace(position=previous.position + 0.1) + success, state, new_step_size_max, new_kinetic = ( + mclmc_module._adaptation_handle_nans( + previous_state=previous, + next_state=next_state, + step_size=jnp.array(0.2), + step_size_max=jnp.array(0.5), + kinetic_change=jnp.array(0.1), + key=random.PRNGKey(2), + ) + ) + assert success + assert_allclose(state.position, next_state.position) + assert_allclose(new_step_size_max, 0.5) + assert_allclose(new_kinetic, 0.1) + + invalid = next_state._replace( + position=jnp.array([jnp.nan, 0.0]), logdensity=jnp.nan + ) + success, state, new_step_size_max, new_kinetic = ( + mclmc_module._adaptation_handle_nans( + previous_state=previous, + next_state=invalid, + step_size=jnp.array(0.2), + step_size_max=jnp.array(0.5), + kinetic_change=jnp.array(0.1), + key=random.PRNGKey(3), + ) + ) + assert not success + assert_allclose(new_step_size_max, 0.2 * mclmc_module._DELTA_NAN_STEP_SIZE_FACTOR) + assert_allclose(new_kinetic, 0.0) + assert_allclose(state.position, previous.position) + + +def test_make_l_step_size_adaptation_returns_finite_positive_params(): + dim = 2 + initial_state = _make_test_state(random.PRNGKey(0)) + params = mclmc_module.MCLMCAdaptationState( + L=jnp.sqrt(dim), + step_size=0.2, + inverse_mass_matrix=jnp.ones((dim,)), + ) + adaptation = mclmc_module._make_l_step_size_adaptation( + kernel_fn=lambda imm: mclmc_module._build_kernel(_gaussian_logdensity, imm), + dim=dim, + frac_tune1=0.2, + frac_tune2=0.2, + diagonal_preconditioning=True, + ) + state, 
new_params = adaptation( + initial_state, + params, + num_steps=30, + rng_key=random.PRNGKey(1), + ) + assert mclmc_module._state_is_finite(state) + assert jnp.isfinite(new_params.L) and (new_params.L > 0) + assert jnp.isfinite(new_params.step_size) and (new_params.step_size > 0) + assert new_params.inverse_mass_matrix.shape == (dim,) + + +def test_make_adaptation_l_nominal_case_updates_l(): + state = _make_test_state(random.PRNGKey(0)) + params = mclmc_module.MCLMCAdaptationState( + L=jnp.array(1.0), + step_size=jnp.array(0.1), + inverse_mass_matrix=jnp.ones((2,)), + ) + kernel = mclmc_module._build_kernel(_gaussian_logdensity, jnp.ones((2,))) + adaptation_l = mclmc_module._make_adaptation_l(kernel=kernel, frac=0.5, lfactor=0.4) + _, new_params = adaptation_l( + state=state, + params=params, + num_steps=12, + rng_key=random.PRNGKey(2), + ) + assert jnp.isfinite(new_params.L) + assert new_params.L > 0 + + +def test_mclmc_find_l_and_step_size_returns_expected_phase_accounting(): + state = _make_test_state(random.PRNGKey(0)) + state, params, total_steps = mclmc_module._mclmc_find_l_and_step_size( + mclmc_kernel=lambda imm: mclmc_module._build_kernel(_gaussian_logdensity, imm), + num_steps=20, + state=state, + rng_key=random.PRNGKey(1), + frac_tune1=0.2, + frac_tune2=0.2, + frac_tune3=0.2, + diagonal_preconditioning=True, + ) + expected_num_steps1 = round(20 * 0.2) + expected_num_steps2 = round(20 * 0.2) + (round(20 * 0.2) // 3) + expected_num_steps3 = round(20 * 0.2) + assert ( + total_steps == expected_num_steps1 + expected_num_steps2 + expected_num_steps3 + ) + assert mclmc_module._state_is_finite(state) + assert jnp.isfinite(params.L) and (params.L > 0) + assert jnp.isfinite(params.step_size) and (params.step_size > 0) + assert params.inverse_mass_matrix.shape == (2,) + + def test_mclmc_model_required(): """Test that ValueError is raised when model is None.""" with pytest.raises(ValueError, match="Model must be specified"): @@ -200,3 +544,116 @@ def 
test_mclmc_small_warmup_runs(): mcmc.run(random.PRNGKey(2)) samples = mcmc.get_samples()["x"] assert samples.shape == (20, 2) + + +def test_mclmc_public_properties_and_diagnostics(): + kernel = MCLMC(model=_two_dim_model) + assert kernel.model is _two_dim_model + assert kernel.sample_field == "position" + assert kernel.default_fields == ("position",) + assert kernel.get_diagnostics_str(None) == "" + kernel.adapt_state = mclmc_module.MCLMCAdaptationState( + L=jnp.array(1.2), step_size=jnp.array(0.05), inverse_mass_matrix=jnp.ones((2,)) + ) + assert "step_size=" in kernel.get_diagnostics_str(None) + assert "L=" in kernel.get_diagnostics_str(None) + + +def test_mclmc_postprocess_fn_identity_when_uninitialized(): + kernel = MCLMC(model=_two_dim_model) + fn = kernel.postprocess_fn((), {}) + x = {"z": jnp.array([1.0, 2.0])} + out = fn(x) + assert out is x + + +def test_mclmc_sample_raises_if_not_initialized(): + kernel = MCLMC(model=_two_dim_model) + state = mclmc_module.FullState( + position=jnp.array([0.0, 0.0]), + momentum=jnp.array([1.0, 0.0]), + logdensity=jnp.array(0.0), + logdensity_grad=jnp.array([0.0, 0.0]), + rng_key=random.PRNGKey(0), + ) + with pytest.raises(RuntimeError, match="must be initialized"): + kernel.sample(state, (), {}) + + +def test_mclmc_init_and_sample_direct_api(): + kernel = MCLMC(model=_two_dim_model) + state = kernel.init( + rng_key=random.PRNGKey(0), + num_warmup=30, + init_params=None, + model_args=(), + model_kwargs={}, + ) + assert isinstance(state, mclmc_module.FullState) + next_state = kernel.sample(state, (), {}) + assert isinstance(next_state, mclmc_module.FullState) + assert next_state.position["x"].shape == (2,) + + +def test_mclmc_postprocess_fn_after_init_returns_callable(): + kernel = MCLMC(model=_two_dim_model) + kernel.init( + rng_key=random.PRNGKey(1), + num_warmup=10, + init_params=None, + model_args=(), + model_kwargs={}, + ) + fn = kernel.postprocess_fn((), {}) + assert callable(fn) + + +def 
test_mclmc_getstate_clears_postprocess_fn(): + kernel = MCLMC(model=_two_dim_model) + kernel._postprocess_fn = lambda *args, **kwargs: lambda x: x + state = kernel.__getstate__() + assert state["_postprocess_fn"] is None + assert state["_model"] is _two_dim_model + + +def test_mclmc_adaptation_l_handles_bad_ess(monkeypatch): + """Test ESS guard keeps L finite for degenerate ESS estimates.""" + state = mclmc_module.MCLMCState( + position=jnp.array([0.0, 0.0]), + momentum=jnp.array([1.0, 0.0]), + logdensity=jnp.array(0.0), + logdensity_grad=jnp.array([0.0, 0.0]), + ) + params = mclmc_module.MCLMCAdaptationState( + L=jnp.array(1.0), + step_size=jnp.array(0.1), + inverse_mass_matrix=jnp.ones((2,)), + ) + + def dummy_kernel(rng_key, state, L, step_size): + del rng_key, L, step_size + return state, mclmc_module.MCLMCInfo( + logdensity=state.logdensity, + kinetic_change=jnp.array(0.0), + energy_change=jnp.array(0.0), + ) + + monkeypatch.setattr( + mclmc_module, + "effective_sample_size", + lambda _: jnp.array([0.0, jnp.nan, jnp.inf]), + ) + + adaptation_l = mclmc_module._make_adaptation_l( + kernel=dummy_kernel, + frac=0.5, + lfactor=0.4, + ) + _, new_params = adaptation_l( + state=state, + params=params, + num_steps=10, + rng_key=random.PRNGKey(0), + ) + assert jnp.isfinite(new_params.L) + assert new_params.L > 0