From e0d920ac896e303c8c216393d4454420c4b3bcd7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 11 Jun 2026 22:06:24 +0800 Subject: [PATCH] fix(random): correct six distribution bugs found in audit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Audit of brainstate.random surfaced six reachable correctness bugs, each contradicting the function's own NumPy-style docstring. All are fixed with test-first regression coverage (TestAuditRegressions in _fun_test.py). - standard_t: with array `df` and `size=None`, output shape was always () (dead `u.math.shape(size)` branch) and raised ValueError. Now infers the shape from `df`, matching the sibling `t` and the docstring. - weibull_min: divided by `scale` instead of multiplying. Now `r * scale`, matching scipy.stats.weibull_min and the `weibull` docstring (lambda scale). - triangular: was `2*bernoulli-1` (Rademacher ±1) with a size-only signature, so the documented `triangular(-3, 0, 8, N)` raised TypeError. Reimplemented as the real triangular(left, mode, right, size) via inverse-CDF, with shared-unit support like `uniform`. - geometric: off-by-one (support {0,1,...} instead of {1,2,...}) and returned float. Now `floor(...) + 1` cast to an integer dtype, so P(k==1)==p. - randint_like: default `high = max(input)` used the Python builtin and raised on >1-D templates. Now uses `u.math.max`. - chisquare: summed `df` squared normals, rejecting non-integer scalar `df` (TypeError) and array `df` with `size=None` (NotImplementedError). Now uses the `2 * Gamma(df/2)` relation, valid for any positive real / array `df`. Tests encoding the old buggy behavior (triangular ±1, chisquare NotImplementedError) are updated to assert the corrected contracts. --- brainstate/random/_fun.py | 18 ++++- brainstate/random/_fun_test.py | 117 +++++++++++++++++++++++++++++-- brainstate/random/_state.py | 93 ++++++++++++++++++------ brainstate/random/_state_test.py | 20 +++--- 4 files changed, 211 insertions(+), 37 deletions(-) diff --git a/brainstate/random/_fun.py b/brainstate/random/_fun.py index a95a70c..7806bc6 100644 --- a/brainstate/random/_fun.py +++ b/brainstate/random/_fun.py @@ -3350,9 +3350,13 @@ def rayleigh( @set_module_as('brainstate.random') def triangular( + left: ArrayLike = 0.0, + mode: ArrayLike = 0.5, + right: ArrayLike = 1.0, size: Optional[Size] = None, - key: Optional[SeedOrKey] = None -) -> jax.Array: + key: Optional[SeedOrKey] = None, + dtype: Optional[DTypeLike] = None +) -> Union[jax.Array, u.Quantity]: r""" Draw samples from the triangular distribution over the interval ``[left, right]``. @@ -3364,6 +3368,14 @@ def triangular( Parameters ---------- + left : float or array_like of floats, optional + Lower limit of the distribution. Default is 0.0. + mode : float or array_like of floats, optional + The value where the peak of the distribution occurs, must satisfy + ``left <= mode <= right``. Default is 0.5. + right : float or array_like of floats, optional + Upper limit of the distribution, must be larger than ``left``. + Default is 1.0. size : int or tuple of ints, optional Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples are drawn. If size is ``None`` (default), @@ -3409,7 +3421,7 @@ def triangular( ... density=True) >>> plt.show() """ - return DEFAULT.triangular(size=size, key=key) + return DEFAULT.triangular(left, mode, right, size=size, key=key, dtype=dtype) @set_module_as('brainstate.random') diff --git a/brainstate/random/_fun_test.py b/brainstate/random/_fun_test.py index 5da60bb..c3d8c6f 100644 --- a/brainstate/random/_fun_test.py +++ b/brainstate/random/_fun_test.py @@ -404,9 +404,11 @@ def test_chisquare1(self): self.assertTrue(a.dtype, float) def test_chisquare2(self): + # Array ``df`` with ``size=None`` infers the shape from ``df`` (gamma-based + # implementation supports non-scalar and non-integer degrees of freedom). brainstate.random.seed() - with self.assertRaises(NotImplementedError): - a = brainstate.random.chisquare(df=[2, 3, 4]) + a = brainstate.random.chisquare(df=[2, 3, 4]) + self.assertTupleEqual(a.shape, (3,)) def test_chisquare3(self): brainstate.random.seed() @@ -542,8 +544,9 @@ def test_rayleigh(self): def test_triangular(self): brainstate.random.seed() - a = brainstate.random.triangular((2, 2)) + a = brainstate.random.triangular(-3.0, 0.0, 8.0, (2, 2)) self.assertTupleEqual(a.shape, (2, 2)) + self.assertTrue(((a >= -3.0) & (a <= 8.0)).all()) def test_vonmises(self): brainstate.random.seed() @@ -680,7 +683,7 @@ def test_t2(self): ("chisquare", lambda size: brainstate.random.chisquare(3, size)), ("t", lambda size: brainstate.random.t(5.0, size)), ("standard_t", lambda size: brainstate.random.standard_t(5.0, size)), - ("triangular", lambda size: brainstate.random.triangular(size)), + ("triangular", lambda size: brainstate.random.triangular(0.0, 0.5, 1.0, size)), ("vonmises", lambda size: brainstate.random.vonmises(0.0, 1.0, size)), ("maxwell", lambda size: brainstate.random.maxwell(size)), ("f", lambda size: brainstate.random.f(2.0, 5.0, size)), @@ -863,3 +866,109 @@ def test_poisson_statistics(self): x = brainstate.random.poisson(3.0, (10000,)) self.assertTrue(bool(jnp.all(x >= 0))) self.assertLess(abs(float(jnp.mean(x)) - 3.0), 0.2) + + +class TestAuditRegressions(parameterized.TestCase): + """Regression tests for bugs found in the brainstate.random audit.""" + + # --- A: standard_t array df with size=None ---------------------------------- + + def test_standard_t_array_df_infers_shape(self): + """standard_t infers output shape from array ``df`` when ``size`` is None.""" + brainstate.random.seed(0) + a = brainstate.random.standard_t([1.0, 2.0, 4.0]) + self.assertTupleEqual(tuple(a.shape), (3,)) + self.assertTrue(bool(jnp.all(jnp.isfinite(a)))) + + def test_standard_t_scalar_df_unchanged(self): + """standard_t with a scalar ``df`` still returns a scalar.""" + brainstate.random.seed(0) + self.assertTupleEqual(tuple(brainstate.random.standard_t(3.0).shape), ()) + + # --- B: weibull_min scale multiplies ----------------------------------------- + + @pytest.mark.slow + def test_weibull_min_scale_multiplies(self): + """weibull_min(a, scale) scales the standard draw by ``scale`` (not 1/scale).""" + brainstate.random.seed(0) + base = np.asarray(brainstate.random.weibull(2.0, 200000)) + brainstate.random.seed(0) + scaled = np.asarray(brainstate.random.weibull_min(2.0, 4.0, 200000)) + # Same key/uniforms, so the ratio is exactly the scale factor elementwise. + np.testing.assert_allclose(scaled / base, 4.0, rtol=1e-4) + + # --- C: triangular is a real triangular distribution ------------------------- + + def test_triangular_within_bounds(self): + """triangular(left, mode, right) stays within [left, right].""" + brainstate.random.seed(0) + a = np.asarray(brainstate.random.triangular(-3.0, 0.0, 8.0, 1000)) + self.assertTupleEqual(a.shape, (1000,)) + self.assertTrue((a >= -3.0).all() and (a <= 8.0).all()) + + def test_triangular_scalar(self): + """triangular returns a scalar when all parameters are scalars and size is None.""" + brainstate.random.seed(0) + self.assertTupleEqual(tuple(brainstate.random.triangular(0.0, 0.5, 1.0).shape), ()) + + @pytest.mark.slow + def test_triangular_mode_skews_mean(self): + """The sample mean approaches the analytic mean (left+mode+right)/3.""" + brainstate.random.seed(0) + a = np.asarray(brainstate.random.triangular(0.0, 2.0, 3.0, 200000)) + self.assertLess(abs(a.mean() - (0.0 + 2.0 + 3.0) / 3.0), 0.02) + + def test_triangular_docstring_example_runs(self): + """The documented ``triangular(-3, 0, 8, N)`` call no longer raises.""" + brainstate.random.seed(0) + a = brainstate.random.triangular(-3, 0, 8, 100) + self.assertTupleEqual(tuple(a.shape), (100,)) + + # --- D: geometric off-by-one and integer dtype ------------------------------- + + def test_geometric_support_starts_at_one(self): + """geometric is supported on the positive integers {1, 2, ...}.""" + brainstate.random.seed(0) + a = np.asarray(brainstate.random.geometric(0.5, size=(5000,))) + self.assertGreaterEqual(int(a.min()), 1) + self.assertTrue(np.issubdtype(a.dtype, np.integer)) + + @pytest.mark.slow + def test_geometric_pmf_first_success(self): + """P(k == 1) approaches ``p`` (NumPy convention).""" + brainstate.random.seed(0) + a = np.asarray(brainstate.random.geometric(0.35, size=(200000,))) + self.assertLess(abs((a == 1).mean() - 0.35), 0.01) + + # --- E: randint_like default high with ndim>1 input -------------------------- + + def test_randint_like_multidim_default_high(self): + """randint_like infers ``high`` from a multi-dimensional template.""" + brainstate.random.seed(0) + template = jnp.array([[3, 7], [2, 9]]) + a = brainstate.random.randint_like(template) + self.assertTupleEqual(tuple(a.shape), (2, 2)) + self.assertTrue(bool(jnp.all(a >= 0)) and bool(jnp.all(a < 9))) + + # --- F: chisquare accepts float and array df --------------------------------- + + def test_chisquare_float_scalar_df(self): + """chisquare accepts a non-integer scalar ``df``.""" + brainstate.random.seed(0) + a = brainstate.random.chisquare(3.5) + self.assertTupleEqual(tuple(a.shape), ()) + self.assertTrue(float(a) >= 0.0) + + def test_chisquare_array_df_infers_shape(self): + """chisquare with array ``df`` and ``size=None`` infers the shape from ``df``.""" + brainstate.random.seed(0) + a = brainstate.random.chisquare(jnp.array([2.0, 3.0, 4.0])) + self.assertTupleEqual(tuple(a.shape), (3,)) + self.assertTrue(bool(jnp.all(a >= 0.0))) + + @pytest.mark.slow + def test_chisquare_mean_matches_df(self): + """A large chi-square sample has mean ~ df.""" + brainstate.random.seed(0) + a = np.asarray(brainstate.random.chisquare(7.0, size=(100000,))) + self.assertLess(abs(a.mean() - 7.0), 0.1) diff --git a/brainstate/random/_state.py b/brainstate/random/_state.py index d7fc35a..78e6b30 100644 --- a/brainstate/random/_state.py +++ b/brainstate/random/_state.py @@ -647,7 +647,9 @@ def standard_t( ): df = _remove_unit_param('df', _check_py_seq(df)) if size is None: - size = u.math.shape(size) if size is not None else () + # Match numpy: with no explicit size, draw one sample per ``df`` entry + # (a scalar ``df`` yields a scalar). + size = u.math.shape(df) key = self.__get_key(key) dtype = dtype or environ.dftype() r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype) @@ -867,14 +869,13 @@ def chisquare( key = self.__get_key(key) dtype = dtype or environ.dftype() if size is None: - if jnp.ndim(df) == 0: - dist = jr.normal(key, (df,), dtype=dtype) ** 2 - dist = dist.sum() - else: - raise NotImplementedError('Do not support non-scale "df" when "size" is None') - else: - dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2 - dist = dist.sum(axis=0) + size = u.math.shape(df) + # A chi-square distribution with ``df`` degrees of freedom is ``2 * Gamma(df/2)``. + # Using the gamma relation (rather than summing ``df`` squared normals) is valid + # for any positive real and array-valued ``df``, matches numpy, and mirrors the + # sibling ``t`` / ``noncentral_chisquare`` implementations. + df = u.math.asarray(df, dtype=dtype) + dist = 2.0 * jr.gamma(key, 0.5 * df, shape=_size2shape(size), dtype=dtype) return dist @named_scope( @@ -911,10 +912,14 @@ def geometric( if size is None: size = u.math.shape(p) key = self.__get_key(key) - dtype = dtype or environ.dftype() - u_ = jr.uniform(key, size, dtype) - r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p)) - return r + float_dtype = environ.dftype() + u_ = jr.uniform(key, _size2shape(size), float_dtype) + # Inverse-CDF sampling. ``floor(log1p(-u) / log1p(-p))`` is supported on + # {0, 1, ...} (number of failures); the geometric distribution counts the + # number of trials until the first success and is supported on {1, 2, ...} + # with ``P(k == 1) == p``, so add one. Samples are integer-valued. + r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p)) + 1 + return u.math.asarray(r, dtype=dtype or environ.ditype()) def _check_p2(self, p): raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}') @@ -1025,18 +1030,59 @@ def rayleigh( @named_scope( 'brainstate/random', - static_argnums=(0, 1), - static_argnames=['size'] + static_argnums=(0, 4, 6), + static_argnames=['dtype', 'size'] ) def triangular( self, + left=0.0, + mode=0.5, + right=1.0, size: Optional[Size] = None, - key: Optional[SeedOrKey] = None + key: Optional[SeedOrKey] = None, + dtype: DTypeLike = None ): + dtype = dtype or environ.dftype() + # ``left``/``mode``/``right`` share a single physical unit, inferred from + # whichever bound carries one (a plain bound is then interpreted in that shared + # unit). A compatible-but-different unit is converted; an incompatible one raises. + values = [ + u.math.asarray(_check_py_seq(v), dtype=dtype) + for v in (left, mode, right) + ] + unit = u.UNITLESS + for v in values: + q = u.Quantity(v) + if not q.is_unitless: + unit = q.unit + break + + def _to_shared_unit(v): + q = u.Quantity(v) + return q.mantissa if q.is_unitless else q.in_unit(unit).mantissa + + left, mode, right = (_to_shared_unit(v) for v in values) + + if size is None: + size = lax.broadcast_shapes( + u.math.shape(left), + u.math.shape(mode), + u.math.shape(right), + ) + size = _size2shape(size) key = self.__get_key(key) - bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size)) - r = 2 * bernoulli_samples - 1 - return r + samples = jr.uniform(key, size, dtype=dtype) + + # Inverse-CDF (quantile) transform of the triangular distribution. With + # ``fc = (mode - left) / (right - left)`` the cut point of the unit uniform, + # draws below ``fc`` map onto the rising edge and the rest onto the falling + # edge. ``where(span > 0, ...)`` guards the degenerate ``left == right`` case. + span = right - left + fc = jnp.where(span > 0, (mode - left) / jnp.where(span > 0, span, 1.0), 0.0) + rising = left + jnp.sqrt(samples * span * (mode - left)) + falling = right - jnp.sqrt((1.0 - samples) * span * (right - mode)) + r = jnp.where(samples < fc, rising, falling) + return u.maybe_decimal(r * unit) @named_scope( 'brainstate/random', @@ -1118,7 +1164,10 @@ def weibull_min( random_uniform = jr.uniform(key=key, shape=size, dtype=dtype) r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) if scale_m is not None: - r = r / scale_m + # ``scale`` is the distribution scale parameter lambda: a 2-parameter + # Weibull draw is ``lambda * (-ln(1-U))**(1/a)`` (matching scipy + # ``weibull_min`` and the ``weibull`` docstring), so multiply by scale. + r = r * scale_m return u.maybe_decimal(r * unit) @named_scope( @@ -1527,7 +1576,9 @@ def randint_like( key: Optional[SeedOrKey] = None ): if high is None: - high = max(input) + # ``max`` (the Python builtin) raises on multi-dimensional arrays; use an + # array reduction so any-rank templates work. + high = u.math.max(input) return self.randint(low, high, size=u.math.shape(input), dtype=dtype, key=key) diff --git a/brainstate/random/_state_test.py b/brainstate/random/_state_test.py index b3e5fa0..69d497c 100644 --- a/brainstate/random/_state_test.py +++ b/brainstate/random/_state_test.py @@ -861,10 +861,11 @@ def test_chisquare_scalar_and_sized(self): self.assertEqual(self.rs.chisquare(3).shape, ()) self.assertEqual(self.rs.chisquare(3, size=(4,)).shape, (4,)) - def test_chisquare_nonscalar_df_requires_size(self): - """chisquare with non-scalar df and no size is unsupported.""" - with self.assertRaises(NotImplementedError): - self.rs.chisquare(jnp.array([2, 3])) + def test_chisquare_nonscalar_df_infers_shape(self): + """chisquare with non-scalar df and no size infers the shape from df.""" + arr = self.rs.chisquare(jnp.array([2.0, 3.0])) + self.assertEqual(arr.shape, (2,)) + self.assertTrue((arr >= 0).all()) def test_dirichlet(self): """dirichlet rows sum to one over the simplex axis.""" @@ -873,10 +874,11 @@ def test_dirichlet(self): np.testing.assert_allclose(np.asarray(arr).sum(axis=-1), 1.0, atol=1e-5) def test_geometric(self): - """geometric yields non-negative integer-valued samples.""" + """geometric yields integer-valued samples supported on {1, 2, ...}.""" arr = self.rs.geometric(0.5, size=(3,)) self.assertEqual(arr.shape, (3,)) - self.assertTrue((arr >= 0).all()) + self.assertTrue((arr >= 1).all()) + self.assertTrue(np.issubdtype(np.asarray(arr).dtype, np.integer)) def test_multinomial(self): """multinomial counts sum to n across the category axis.""" @@ -932,10 +934,10 @@ def test_rayleigh(self): self.assertTrue((arr >= 0).all()) def test_triangular(self): - """triangular returns values in {-1, 1}.""" - arr = self.rs.triangular(size=(50,)) + """triangular draws lie within [left, right] with the requested shape.""" + arr = self.rs.triangular(-1.0, 0.0, 2.0, size=(50,)) self.assertEqual(arr.shape, (50,)) - self.assertTrue(jnp.all((arr == -1) | (arr == 1))) + self.assertTrue(jnp.all((arr >= -1.0) & (arr <= 2.0))) def test_vonmises(self): """vonmises returns angles within (-pi, pi]."""