From 04ddc7794981884cdce96bab3a341ca69340786b Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 16:16:00 -0400 Subject: [PATCH 1/8] Fix `_DistributionTerm.expand` return-type annotation `expand` was `@defop`-annotated to return `jax.Array`, which routed its result through `_ArrayTerm` and broke chained distribution methods like `Normal(mu_term, 1.0).expand([J]).to_event(1)`. Annotate the return type as `dist.Distribution` so the result is no longer dispatched as an array. Note: `to_event` was already correctly annotated. --- effectful/handlers/numpyro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index f0369d37..0791daca 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -357,7 +357,7 @@ def to_event(self, reinterpreted_batch_ndims=None) -> dist.Distribution: raise NotHandled @defop - def expand(self, batch_shape) -> jax.Array: + def expand(self, batch_shape) -> dist.Distribution: if not self._is_eager: raise NotHandled From fadd5087e8f4be67970935e08f095832fb3d70cb Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 16:30:32 -0400 Subject: [PATCH 2/8] Register `_DistributionMethodTerm` fallback for `dist.Distribution` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without this, `defdata` dispatch on `dist.Distribution` falls through to `collections.abc.Callable` (since `Distribution` defines `__call__`) and produces `_CallableTerm`, breaking chained distribution methods like `Normal(mu_term, 1.0).expand([J]).to_event(1)`. The fallback subclasses `_DistributionTerm` and carries forward the receiver's distribution family, so further chained methods remain available on the resulting term. Add a regression test pinning the chain-construction behavior. Note: this resolves the AttributeError reported in #666. Running the full repro under MCMC still fails downstream because other distribution property defops (`support`, `batch_shape`, ...) raise `NotHandled` on non-eager receivers and route through the same defdata pattern — a broader integration matter, not the chain-construction bug. --- effectful/handlers/numpyro.py | 26 ++++++++++++++++++++++++++ tests/test_handlers_numpyro.py | 20 ++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index 0791daca..c4f57b1d 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -396,6 +396,32 @@ def __str__(self): expand = _DistributionTerm.expand +@defdata.register(dist.Distribution) +class _DistributionMethodTerm(_DistributionTerm): + """Fallback term for distribution-method ops whose return annotation is the + abstract ``dist.Distribution`` (e.g. ``expand`` and ``to_event`` when the + receiver is non-eager). + + Without this registration, ``defdata`` dispatch on ``dist.Distribution`` falls + through to ``collections.abc.Callable`` (since ``Distribution`` defines + ``__call__``) and produces a ``_CallableTerm``, breaking chained method calls + like ``Normal(mu_term, 1.0).expand([J]).to_event(1)``. See #666. + + The receiver of the deferred op is taken as the first positional argument; we + carry forward its distribution family so further chained methods remain + available on the resulting term. + """ + + def __init__(self, ty, op, *args, **kwargs): + receiver = args[0] if args else None + constr = ( + receiver._constr + if isinstance(receiver, _DistributionTerm) + else dist.Distribution + ) + super().__init__(constr, op, *args, **kwargs) + + @defop def Cauchy(loc=0.0, scale=1.0, **kwargs) -> dist.Cauchy: raise NotHandled diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index 270def25..327be8ba 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -929,3 +929,23 @@ def test_distribution_typeof(): typeof(dist.Normal(jax_getitem(jnp.array([0, 1, 2]), [defop(jax.Array)()]))) is numpyro.distributions.continuous.Normal ) + + +def test_distribution_method_chain_on_non_eager_term(): + """Regression test for #666. + + ``Normal(mu_term, 1.0).expand([J]).to_event(1)`` must not raise + ``AttributeError`` mid-chain. Previously ``_DistributionTerm.expand`` was + ``@defop``-annotated to return ``jax.Array``, routing ``.expand([J])``'s + result through ``_ArrayTerm`` (no ``.to_event``). The fix annotates + ``expand`` to return ``dist.Distribution`` and registers a fallback + ``_DistributionMethodTerm`` for ``defdata`` dispatch on the abstract base, + so the chain stays in the distribution-term surface. + """ + mu = defop(jax.Array, name="mu") + + expanded = dist.Normal(mu(), 1.0).expand([3]) + assert isinstance(expanded, numpyro.distributions.Distribution) + + chained = expanded.to_event(1) + assert isinstance(chained, numpyro.distributions.Distribution) From 3a2b467ea1f384adf1a93711bf55f1dac8101091 Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 16:40:55 -0400 Subject: [PATCH 3/8] Add end-to-end MCMC regression test for #666 The original issue framed the bug as blocking "the standard NumPyro vectorised hierarchical-model idiom". Add a test that exercises the idiomatic form (`numpyro.plate` around a sample) end-to-end under MCMC and asserts sample shapes. Companion to the existing narrow chain test. --- tests/test_handlers_numpyro.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index 327be8ba..4d5fd9db 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -932,7 +932,7 @@ def test_distribution_typeof(): def test_distribution_method_chain_on_non_eager_term(): - """Regression test for #666. + """Regression test for #666 (narrow). ``Normal(mu_term, 1.0).expand([J]).to_event(1)`` must not raise ``AttributeError`` mid-chain. Previously ``_DistributionTerm.expand`` was @@ -949,3 +949,30 @@ def test_distribution_method_chain_on_non_eager_term(): chained = expanded.to_event(1) assert isinstance(chained, numpyro.distributions.Distribution) + + +def test_vectorised_hierarchical_model_mcmc(): + """End-to-end regression for #666. + + The issue framed the bug as blocking "the standard NumPyro vectorised + hierarchical-model idiom". With the annotation fix, the idiomatic form + (``numpyro.plate`` around a sample) traces correctly under MCMC and + produces samples of the expected shape. + """ + import jax.random as jr + + def model(): + mu = numpyro.sample("mu", dist.Normal(0.0, 1.0)) + with numpyro.plate("j", 3): + numpyro.sample("theta", dist.Normal(mu, 1.0)) + + mcmc = numpyro.infer.MCMC( + numpyro.infer.NUTS(model), + num_warmup=20, + num_samples=20, + progress_bar=False, + ) + mcmc.run(jr.PRNGKey(0)) + samples = mcmc.get_samples() + assert samples["mu"].shape == (20,) + assert samples["theta"].shape == (20, 3) From 92e3d0589616cd596ec92b7091d1a055f1d88d3d Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 16:43:17 -0400 Subject: [PATCH 4/8] Remove non-regression end-to-end test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The plate-based MCMC test was using `Normal(mu, 1.0)` where `mu` is a sampled real array, so the receiver was fully eager and `.expand` / `.to_event` were never called. The test passed on master too — it didn't pin anything #666-specific. The narrow chain test does the actual job and is verified to fail on master. --- tests/test_handlers_numpyro.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index 4d5fd9db..1b0c3bf3 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -951,28 +951,3 @@ def test_distribution_method_chain_on_non_eager_term(): assert isinstance(chained, numpyro.distributions.Distribution) -def test_vectorised_hierarchical_model_mcmc(): - """End-to-end regression for #666. - - The issue framed the bug as blocking "the standard NumPyro vectorised - hierarchical-model idiom". With the annotation fix, the idiomatic form - (``numpyro.plate`` around a sample) traces correctly under MCMC and - produces samples of the expected shape. - """ - import jax.random as jr - - def model(): - mu = numpyro.sample("mu", dist.Normal(0.0, 1.0)) - with numpyro.plate("j", 3): - numpyro.sample("theta", dist.Normal(mu, 1.0)) - - mcmc = numpyro.infer.MCMC( - numpyro.infer.NUTS(model), - num_warmup=20, - num_samples=20, - progress_bar=False, - ) - mcmc.run(jr.PRNGKey(0)) - samples = mcmc.get_samples() - assert samples["mu"].shape == (20,) - assert samples["theta"].shape == (20, 3) From 67068fe97c4960de1613cd1bbc44c2eff5de5005 Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 17:26:05 -0400 Subject: [PATCH 5/8] Teach _DistributionMethodTerm to materialise via the receiver chain MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review feedback, no new term classes. Two minimal changes: 1. ``_DistributionTerm._is_eager`` (on the base) now recurses into Term arguments that are themselves ``_DistributionTerm``s, consulting their ``_is_eager``. Previously any ``_DistributionTerm``-valued arg was treated as non-eager (since ``is_eager_array`` returned False for Distribution Terms), which prevented downstream methods from resolving even when the underlying base was eager. 2. ``_DistributionMethodTerm._pos_base_dist`` builds a real NumPyro distribution by recursively materialising the receiver and applying the deferred op (``to_event`` / ``expand``) directly. With this and the ``_is_eager`` fix, every inherited ``@defop`` method (``batch_shape``, ``event_shape``, ``support``, ``log_prob``, ``sample``, ...) resolves via the standard machinery — no parallel property/sample overrides needed. Adds two regression tests: - equational shape/support laws for ``.expand`` and ``.to_event`` on a free-variable receiver bound by an effectful handler - end-to-end MCMC over the literal #666 idiom (``Normal(mu_term, 1.0).expand([3]).to_event(1)`` with ``mu_term`` bound via ``handler``) Both fail on master with the expected error mode and pass on this branch. --- effectful/handlers/numpyro.py | 36 ++++++++++++++++-- tests/test_handlers_numpyro.py | 69 ++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 4 deletions(-) diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index c4f57b1d..c0bd0d76 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -226,10 +226,12 @@ def _pos_base_dist(self) -> dist.Distribution: @functools.cached_property def _is_eager(self) -> bool: - return all( - (not isinstance(x, Term) or is_eager_array(x)) - for x in (*self.args, *self.kwargs.values()) - ) + def _arg_is_eager(x): + if isinstance(x, _DistributionTerm): + return x._is_eager + return not isinstance(x, Term) or is_eager_array(x) + + return all(_arg_is_eager(x) for x in (*self.args, *self.kwargs.values())) @property def op(self): @@ -421,6 +423,32 @@ def __init__(self, ty, op, *args, **kwargs): ) super().__init__(constr, op, *args, **kwargs) + @functools.cached_property + def _pos_base_dist(self) -> dist.Distribution: + # Materialise the deferred method op against a real NumPyro distribution + # built from the (now-eager) receiver. Inherited ``@defop`` methods + # (``batch_shape``, ``event_shape``, ``support``, ``log_prob``, + # ``sample``, ...) then resolve via the standard ``_DistributionTerm`` + # machinery against this base. + receiver = self._args[0] + base = ( + receiver._pos_base_dist + if isinstance(receiver, _DistributionTerm) + else receiver + ) + if self._op is _DistributionTerm.to_event: + n = ( + self._args[1] + if len(self._args) > 1 + else self._kwargs.get("reinterpreted_batch_ndims") + ) + return base.to_event(n) + if self._op is _DistributionTerm.expand: + return base.expand(self._args[1]) + raise NotImplementedError( + f"_DistributionMethodTerm._pos_base_dist for op {self._op}" + ) + @defop def Cauchy(loc=0.0, scale=1.0, **kwargs) -> dist.Cauchy: diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index 1b0c3bf3..b5dacb4c 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -951,3 +951,72 @@ def test_distribution_method_chain_on_non_eager_term(): assert isinstance(chained, numpyro.distributions.Distribution) +def test_expand_to_event_shape_laws(): + """Equational laws for ``.expand`` and ``.to_event`` on a distribution term + whose free-variable arg has been bound by an effectful handler. + + These hold for any NumPyro distribution and should survive any future + refactor of how deferred method ops are encoded: + + d.expand(s).batch_shape == tuple(s) + d.expand(s).event_shape == d.event_shape + d.to_event(k).event_shape == d.batch_shape[-k:] + d.event_shape + d.to_event(k).batch_shape == d.batch_shape[:-k] + """ + import jax.numpy as jnp + + from effectful.ops.semantics import handler + + mu = defop(jax.Array, name="mu") + + with handler({mu: lambda: jnp.array(0.0)}): + d = dist.Normal(mu(), 1.0) + assert d.batch_shape == () + assert d.event_shape == () + + expanded = d.expand([3, 4]) + assert expanded.batch_shape == (3, 4) + assert expanded.event_shape == () + + indep = expanded.to_event(1) + assert indep.batch_shape == (3,) + assert indep.event_shape == (4,) + + chained = d.expand([3]).to_event(1) + assert chained.batch_shape == () + assert chained.event_shape == (3,) + assert not chained.support.is_discrete + + +def test_expand_to_event_chain_end_to_end_mcmc(): + """End-to-end regression: the literal #666 idiom — ``Normal(mu_term, 1.0) + .expand([J]).to_event(1)`` with ``mu_term`` bound by an effectful handler — + must trace, build a potential, and run MCMC to completion. + + Before the fix this raised ``AttributeError: '_ArrayTerm' object has no + attribute 'to_event'`` at chain construction. After the fix, the chain + constructs a ``_DistributionMethodTerm`` whose materialised + ``_pos_base_dist`` resolves to a real ``dist.Independent`` wrapping the + handler-bound receiver, so NumPyro's downstream property/sample/log_prob + accesses all resolve. + """ + import jax.numpy as jnp + import jax.random as jr + + from effectful.ops.semantics import handler + + mu = defop(jax.Array, name="mu") + + def model(): + numpyro.sample("theta", dist.Normal(mu(), 1.0).expand([3]).to_event(1)) + + with handler({mu: lambda: jnp.array(0.0)}): + mcmc = numpyro.infer.MCMC( + numpyro.infer.NUTS(model), + num_warmup=20, + num_samples=20, + progress_bar=False, + ) + mcmc.run(jr.PRNGKey(0)) + + assert mcmc.get_samples()["theta"].shape == (20, 3) From 154c845493569f8d5e97064dd0a9eeb0557a3994 Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 17:36:24 -0400 Subject: [PATCH 6/8] Materialise via getattr instead of per-op switch; un-xfail to_event cases Two cleanups: 1. `_DistributionMethodTerm._pos_base_dist` previously branched on ``self._op is _DistributionTerm.to_event`` / ``... is .expand`` to pick the materialisation. Replace the if-chain with a single ``getattr(base, self._op.__name__)(*self._args[1:], **self._kwargs)``. Same semantics for the two existing ops, and now correct for any future ``dist.Distribution``-returning method op without a per-op switch. 2. Remove ``xfail="to_event not implemented"`` from the ``Beta(...).to_event(k)`` and ``Dirichlet(...).to_event(k)`` parametrised cases; they pass now under this PR. The ``Independent(TransformedDistribution(...), k)`` case stays xfail but the reason is updated to ``"TransformedDistribution not implemented"`` to reflect the actual blocker (TransformedDistribution itself is unsupported in effectful). --- effectful/handlers/numpyro.py | 21 ++++----------------- tests/test_handlers_numpyro.py | 4 +--- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index c0bd0d76..f5401d22 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -425,29 +425,16 @@ def __init__(self, ty, op, *args, **kwargs): @functools.cached_property def _pos_base_dist(self) -> dist.Distribution: - # Materialise the deferred method op against a real NumPyro distribution - # built from the (now-eager) receiver. Inherited ``@defop`` methods - # (``batch_shape``, ``event_shape``, ``support``, ``log_prob``, - # ``sample``, ...) then resolve via the standard ``_DistributionTerm`` - # machinery against this base. + # Materialise by recursively building the receiver's _pos_base_dist + # and invoking NumPyro's method of the same name on it. Works for any + # ``dist.Distribution``-returning method op without a per-op switch. receiver = self._args[0] base = ( receiver._pos_base_dist if isinstance(receiver, _DistributionTerm) else receiver ) - if self._op is _DistributionTerm.to_event: - n = ( - self._args[1] - if len(self._args) > 1 - else self._kwargs.get("reinterpreted_batch_ndims") - ) - return base.to_event(n) - if self._op is _DistributionTerm.expand: - return base.expand(self._args[1]) - raise NotImplementedError( - f"_DistributionMethodTerm._pos_base_dist for op {self._op}" - ) + return getattr(base, self._op.__name__)(*self._args[1:], **self._kwargs) @defop diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index b5dacb4c..b993a5fd 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -480,7 +480,6 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ("concentration0", f"exp(rand({batch_shape + indep_shape}))"), ), batch_shape, - xfail="to_event not implemented", ) # Dirichlet.to_event @@ -494,7 +493,6 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ), ), batch_shape, - xfail="to_event not implemented", ) # TransformedDistribution.to_event @@ -513,7 +511,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ("high", f"2. + rand({batch_shape + indep_shape})"), ), batch_shape, - xfail="to_event not implemented", + xfail="TransformedDistribution not implemented", ) From af7e31471b690ad1f0d16dbb6caf631619bc016d Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 17:42:02 -0400 Subject: [PATCH 7/8] Tighten _DistributionMethodTerm docstring and comment --- effectful/handlers/numpyro.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index f5401d22..329cf471 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -400,19 +400,9 @@ def __str__(self): @defdata.register(dist.Distribution) class _DistributionMethodTerm(_DistributionTerm): - """Fallback term for distribution-method ops whose return annotation is the - abstract ``dist.Distribution`` (e.g. ``expand`` and ``to_event`` when the - receiver is non-eager). - - Without this registration, ``defdata`` dispatch on ``dist.Distribution`` falls - through to ``collections.abc.Callable`` (since ``Distribution`` defines - ``__call__``) and produces a ``_CallableTerm``, breaking chained method calls - like ``Normal(mu_term, 1.0).expand([J]).to_event(1)``. See #666. - - The receiver of the deferred op is taken as the first positional argument; we - carry forward its distribution family so further chained methods remain - available on the resulting term. - """ + """Term for distribution-method ops returning the abstract ``dist.Distribution`` + (``expand``, ``to_event``). Catches the ``defdata`` fallthrough that would + otherwise hit ``_CallableTerm``. See #666.""" def __init__(self, ty, op, *args, **kwargs): receiver = args[0] if args else None @@ -425,9 +415,7 @@ def __init__(self, ty, op, *args, **kwargs): @functools.cached_property def _pos_base_dist(self) -> dist.Distribution: - # Materialise by recursively building the receiver's _pos_base_dist - # and invoking NumPyro's method of the same name on it. Works for any - # ``dist.Distribution``-returning method op without a per-op switch. + # Delegate to NumPyro's method of the same name on the materialised receiver. receiver = self._args[0] base = ( receiver._pos_base_dist From 9d793588c1052aa742fc12cfc5ba241bc953af53 Mon Sep 17 00:00:00 2001 From: datvo06 Date: Tue, 26 May 2026 17:59:58 -0400 Subject: [PATCH 8/8] Revert un-xfail of parametrised Beta/Dirichlet to_event cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The basic ``D.to_event(k)`` chain works after this PR, but ``test_dist_expand`` then composes ``.expand_by`` on the resulting ``_DistributionMethodTerm`` whose receiver carries indexed (named-dim) arrays. Materialising via the receiver's ``_pos_base_dist`` and then ``base.expand(...)`` mishandles those indices (``Cannot broadcast distribution of shape (5, 3) to shape (3,)``) — that's a deeper indexed-dim integration matter, out of scope for the #666 annotation fix. Restore the xfail with a more accurate reason. --- tests/test_handlers_numpyro.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index b993a5fd..befd5cbe 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -480,6 +480,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ("concentration0", f"exp(rand({batch_shape + indep_shape}))"), ), batch_shape, + xfail="to_event composed with expand_by on indexed dims not implemented", ) # Dirichlet.to_event @@ -493,6 +494,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ), ), batch_shape, + xfail="to_event composed with expand_by on indexed dims not implemented", ) # TransformedDistribution.to_event