| | |
|---|---|
| Authors | Matvei Kreinin, Maria Nikitina, Petr Babkin, Iryna Zabarianska |
| Consultant | Oleg Bakhteev, PhD |
| Paper | Figurnov et al., Implicit Reparameterization Gradients, NeurIPS 2018 |
This library implements implicit reparameterization gradients for continuous distributions
that lack tractable inverse CDFs. It provides drop-in replacements for torch.distributions
classes with full support for reparameterized sampling (rsample), enabling
gradient-based optimization through stochastic nodes.
The key idea from the paper: instead of inverting the CDF explicitly, compute reparameterization gradients via implicit differentiation:
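For a univariate sample $z \sim q(z \mid \phi)$ with CDF $F(z \mid \phi)$, differentiating the standardization identity $F(z \mid \phi) = \varepsilon$ (where $\varepsilon$ does not depend on $\phi$) gives the paper's implicit gradient:

$$
\nabla_\phi z = -\frac{\nabla_\phi F(z \mid \phi)}{q(z \mid \phi)},
$$

which requires only the CDF and the density, never the inverse CDF.

Supported distributions and the method used for each: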
| Distribution | Parameters | Method |
|---|---|---|
| Normal | loc, scale | Implicit standardization |
| Gamma | concentration, rate | Implicit CDF + scaling |
| Beta | concentration1, concentration0 | Via Gamma ratio |
| Dirichlet | concentration | Via Gamma normalization |
| StudentT | df, loc, scale | Via Gamma-Normal mixture |
| VonMises | loc, concentration | CDF series / normal approx. |
| MixtureSameFamily | mixture, components | Distributional transform |
| ImplicitReparam | any base distribution | Universal CDF wrapper (Eq. 8) |
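As a quick sanity check that gradients really flow through these samplers, the sketch below estimates $\mathrm{d}\,\mathbb{E}[z] / \mathrm{d}\,\text{concentration}$ by Monte Carlo; it assumes the library's Gamma takes the same (concentration, rate) constructor as torch.distributions.Gamma, as "drop-in replacement" implies. Since $\mathbb{E}[z] = \text{concentration} / \text{rate}$, the estimate should be close to 1/rate = 0.5:

```python
import torch
from irt.distributions import Gamma  # assumed drop-in for torch.distributions.Gamma

concentration = torch.tensor([3.0], requires_grad=True)
rate = torch.tensor([2.0])

# E[z] = concentration / rate, so dE[z]/d(concentration) = 1 / rate = 0.5.
z = Gamma(concentration, rate).rsample(torch.Size([100_000]))
z.mean().backward()
print(concentration.grad)  # Monte Carlo estimate, close to 0.5
```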
```bash
git clone https://github.com/intsystems/implicit-reparameterization-trick.git
cd implicit-reparameterization-trick
pip install src/
```

Reparameterized sampling from a Beta distribution:
```python
import torch
from irt.distributions import Beta

alpha = torch.tensor([2.0], requires_grad=True)
beta = torch.tensor([5.0], requires_grad=True)
dist = Beta(alpha, beta)
z = dist.rsample(torch.Size([64]))  # gradients flow to alpha and beta
```
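The table's "via Gamma ratio" entry refers to the standard construction $z = x / (x + y)$ with $x \sim \mathrm{Gamma}(\alpha, 1)$ and $y \sim \mathrm{Gamma}(\beta, 1)$. A sketch of that construction using stock PyTorch, whose Gamma rsample already implements the paper's implicit gradients (illustration only, not the library's internal code):

```python
import torch
from torch.distributions import Gamma

alpha = torch.tensor([2.0], requires_grad=True)
beta = torch.tensor([5.0], requires_grad=True)

# A Beta(alpha, beta) sample as a ratio of Gamma samples; both rsample
# calls are differentiable, so gradients reach alpha and beta.
x = Gamma(alpha, torch.ones(1)).rsample(torch.Size([64]))
y = Gamma(beta, torch.ones(1)).rsample(torch.Size([64]))
z = x / (x + y)
```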
Wrapping any distribution with a tractable CDF via ImplicitReparam:

```python
import torch
from irt.distributions import ImplicitReparam

loc = torch.tensor(0.0, requires_grad=True)
base = torch.distributions.Laplace(loc, 1.0)
dist = ImplicitReparam(base)
z = dist.rsample(torch.Size([64]))  # gradients flow to loc
```
Mixture of distributions:

```python
import torch
from torch.distributions import Categorical
from irt.distributions import Normal, MixtureSameFamily

mix_weights = Categorical(torch.tensor([0.3, 0.7]))
components = Normal(
    torch.tensor([-1.0, 1.0], requires_grad=True),
    torch.tensor([0.5, 0.5]),
)
mixture = MixtureSameFamily(mix_weights, components)
z = mixture.rsample(torch.Size([64]))  # gradients flow to the component means
```
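Mixtures have no tractable inverse CDF, but the paper's distributional transform $u_i = F(z_i \mid z_1, \ldots, z_{i-1})$ maps a sample to coordinate-wise uniforms, so the univariate implicit gradient above applies component by component.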
VAE trained on dynamically binarized MNIST, following the setup in Table 4 of the paper.
Architecture: FC encoder (784-256-128) and decoder (128-256-784), 30 epochs, Adam optimizer
with KL annealing. Results are averaged over 3 random seeds.
Full reproduction in code/vae_demo.ipynb.
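A minimal sketch of one ELBO evaluation under this setup. The layer sizes follow the architecture above; the helper names, the closed-form Gaussian KL, and the fixed kl_weight are illustrative assumptions, not the notebook's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from irt.distributions import Normal  # any family above can be swapped in as the posterior

D = 2  # latent dimensionality of the D=2 models shown below
encoder = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU())
to_loc, to_log_scale = nn.Linear(128, D), nn.Linear(128, D)
decoder = nn.Sequential(nn.Linear(D, 128), nn.ReLU(), nn.Linear(128, 256), nn.ReLU(),
                        nn.Linear(256, 784))

def negative_elbo(x, kl_weight=1.0):  # kl_weight ramps toward 1 under KL annealing
    h = encoder(x)
    loc, scale = to_loc(h), to_log_scale(h).exp()
    z = Normal(loc, scale).rsample()  # implicitly reparameterized sample
    recon = F.binary_cross_entropy_with_logits(decoder(z), x, reduction="none").sum(-1)
    kl = 0.5 * (loc.pow(2) + scale.pow(2) - 1 - 2 * scale.log()).sum(-1)  # KL(q || N(0, I))
    return (recon + kl_weight * kl).mean()

x = torch.bernoulli(torch.rand(32, 784))  # stand-in for a binarized MNIST batch
negative_elbo(x, kl_weight=0.5).backward()
```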
Lower is better. Each cell shows mean and standard deviation over 3 runs.
Encodings of the MNIST test set in 2D latent space, colored by digit class. Each panel corresponds to a different posterior distribution family.
Samples drawn from the prior of each D=2 model and decoded into images.
- M. Figurnov, S. Mohamed, A. Mnih. Implicit Reparameterization Gradients. NeurIPS 2018.
- Documentation
- Blog Post


