Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions brainstate/random/_fun.py
Original file line number Diff line number Diff line change
Expand Up @@ -3350,9 +3350,13 @@ def rayleigh(

@set_module_as('brainstate.random')
def triangular(
left: ArrayLike = 0.0,
mode: ArrayLike = 0.5,
right: ArrayLike = 1.0,
size: Optional[Size] = None,
key: Optional[SeedOrKey] = None
) -> jax.Array:
key: Optional[SeedOrKey] = None,
dtype: Optional[DTypeLike] = None
) -> Union[jax.Array, u.Quantity]:
r"""
Draw samples from the triangular distribution over the
interval ``[left, right]``.
Expand All @@ -3364,6 +3368,14 @@ def triangular(

Parameters
----------
left : float or array_like of floats, optional
Lower limit of the distribution. Default is 0.0.
mode : float or array_like of floats, optional
The value where the peak of the distribution occurs, must satisfy
``left <= mode <= right``. Default is 0.5.
right : float or array_like of floats, optional
Upper limit of the distribution, must be larger than ``left``.
Default is 1.0.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
Expand Down Expand Up @@ -3409,7 +3421,7 @@ def triangular(
... density=True)
>>> plt.show()
"""
return DEFAULT.triangular(size=size, key=key)
return DEFAULT.triangular(left, mode, right, size=size, key=key, dtype=dtype)


@set_module_as('brainstate.random')
Expand Down
117 changes: 113 additions & 4 deletions brainstate/random/_fun_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,11 @@ def test_chisquare1(self):
self.assertTrue(a.dtype, float)

def test_chisquare2(self):
# Array ``df`` with ``size=None`` infers the shape from ``df`` (gamma-based
# implementation supports non-scalar and non-integer degrees of freedom).
brainstate.random.seed()
with self.assertRaises(NotImplementedError):
a = brainstate.random.chisquare(df=[2, 3, 4])
a = brainstate.random.chisquare(df=[2, 3, 4])
self.assertTupleEqual(a.shape, (3,))

def test_chisquare3(self):
brainstate.random.seed()
Expand Down Expand Up @@ -542,8 +544,9 @@ def test_rayleigh(self):

def test_triangular(self):
brainstate.random.seed()
a = brainstate.random.triangular((2, 2))
a = brainstate.random.triangular(-3.0, 0.0, 8.0, (2, 2))
self.assertTupleEqual(a.shape, (2, 2))
self.assertTrue(((a >= -3.0) & (a <= 8.0)).all())

def test_vonmises(self):
brainstate.random.seed()
Expand Down Expand Up @@ -680,7 +683,7 @@ def test_t2(self):
("chisquare", lambda size: brainstate.random.chisquare(3, size)),
("t", lambda size: brainstate.random.t(5.0, size)),
("standard_t", lambda size: brainstate.random.standard_t(5.0, size)),
("triangular", lambda size: brainstate.random.triangular(size)),
("triangular", lambda size: brainstate.random.triangular(0.0, 0.5, 1.0, size)),
("vonmises", lambda size: brainstate.random.vonmises(0.0, 1.0, size)),
("maxwell", lambda size: brainstate.random.maxwell(size)),
("f", lambda size: brainstate.random.f(2.0, 5.0, size)),
Expand Down Expand Up @@ -863,3 +866,109 @@ def test_poisson_statistics(self):
x = brainstate.random.poisson(3.0, (10000,))
self.assertTrue(bool(jnp.all(x >= 0)))
self.assertLess(abs(float(jnp.mean(x)) - 3.0), 0.2)


class TestAuditRegressions(parameterized.TestCase):
"""Regression tests for bugs found in the brainstate.random audit."""

# --- A: standard_t array df with size=None ----------------------------------

def test_standard_t_array_df_infers_shape(self):
"""standard_t infers output shape from array ``df`` when ``size`` is None."""
brainstate.random.seed(0)
a = brainstate.random.standard_t([1.0, 2.0, 4.0])
self.assertTupleEqual(tuple(a.shape), (3,))
self.assertTrue(bool(jnp.all(jnp.isfinite(a))))

def test_standard_t_scalar_df_unchanged(self):
"""standard_t with a scalar ``df`` still returns a scalar."""
brainstate.random.seed(0)
self.assertTupleEqual(tuple(brainstate.random.standard_t(3.0).shape), ())

# --- B: weibull_min scale multiplies -----------------------------------------

@pytest.mark.slow
def test_weibull_min_scale_multiplies(self):
"""weibull_min(a, scale) scales the standard draw by ``scale`` (not 1/scale)."""
brainstate.random.seed(0)
base = np.asarray(brainstate.random.weibull(2.0, 200000))
brainstate.random.seed(0)
scaled = np.asarray(brainstate.random.weibull_min(2.0, 4.0, 200000))
# Same key/uniforms, so the ratio is exactly the scale factor elementwise.
np.testing.assert_allclose(scaled / base, 4.0, rtol=1e-4)

# --- C: triangular is a real triangular distribution -------------------------

def test_triangular_within_bounds(self):
"""triangular(left, mode, right) stays within [left, right]."""
brainstate.random.seed(0)
a = np.asarray(brainstate.random.triangular(-3.0, 0.0, 8.0, 1000))
self.assertTupleEqual(a.shape, (1000,))
self.assertTrue((a >= -3.0).all() and (a <= 8.0).all())

def test_triangular_scalar(self):
"""triangular returns a scalar when all parameters are scalars and size is None."""
brainstate.random.seed(0)
self.assertTupleEqual(tuple(brainstate.random.triangular(0.0, 0.5, 1.0).shape), ())

@pytest.mark.slow
def test_triangular_mode_skews_mean(self):
"""The sample mean approaches the analytic mean (left+mode+right)/3."""
brainstate.random.seed(0)
a = np.asarray(brainstate.random.triangular(0.0, 2.0, 3.0, 200000))
self.assertLess(abs(a.mean() - (0.0 + 2.0 + 3.0) / 3.0), 0.02)

def test_triangular_docstring_example_runs(self):
"""The documented ``triangular(-3, 0, 8, N)`` call no longer raises."""
brainstate.random.seed(0)
a = brainstate.random.triangular(-3, 0, 8, 100)
self.assertTupleEqual(tuple(a.shape), (100,))

# --- D: geometric off-by-one and integer dtype -------------------------------

def test_geometric_support_starts_at_one(self):
"""geometric is supported on the positive integers {1, 2, ...}."""
brainstate.random.seed(0)
a = np.asarray(brainstate.random.geometric(0.5, size=(5000,)))
self.assertGreaterEqual(int(a.min()), 1)
self.assertTrue(np.issubdtype(a.dtype, np.integer))

@pytest.mark.slow
def test_geometric_pmf_first_success(self):
"""P(k == 1) approaches ``p`` (NumPy convention)."""
brainstate.random.seed(0)
a = np.asarray(brainstate.random.geometric(0.35, size=(200000,)))
self.assertLess(abs((a == 1).mean() - 0.35), 0.01)

# --- E: randint_like default high with ndim>1 input --------------------------

def test_randint_like_multidim_default_high(self):
"""randint_like infers ``high`` from a multi-dimensional template."""
brainstate.random.seed(0)
template = jnp.array([[3, 7], [2, 9]])
a = brainstate.random.randint_like(template)
self.assertTupleEqual(tuple(a.shape), (2, 2))
self.assertTrue(bool(jnp.all(a >= 0)) and bool(jnp.all(a < 9)))

# --- F: chisquare accepts float and array df ---------------------------------

def test_chisquare_float_scalar_df(self):
"""chisquare accepts a non-integer scalar ``df``."""
brainstate.random.seed(0)
a = brainstate.random.chisquare(3.5)
self.assertTupleEqual(tuple(a.shape), ())
self.assertTrue(float(a) >= 0.0)

def test_chisquare_array_df_infers_shape(self):
"""chisquare with array ``df`` and ``size=None`` infers the shape from ``df``."""
brainstate.random.seed(0)
a = brainstate.random.chisquare(jnp.array([2.0, 3.0, 4.0]))
self.assertTupleEqual(tuple(a.shape), (3,))
self.assertTrue(bool(jnp.all(a >= 0.0)))

@pytest.mark.slow
def test_chisquare_mean_matches_df(self):
"""A large chi-square sample has mean ~ df."""
brainstate.random.seed(0)
a = np.asarray(brainstate.random.chisquare(7.0, size=(100000,)))
self.assertLess(abs(a.mean() - 7.0), 0.1)
93 changes: 72 additions & 21 deletions brainstate/random/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,9 @@ def standard_t(
):
df = _remove_unit_param('df', _check_py_seq(df))
if size is None:
size = u.math.shape(size) if size is not None else ()
# Match numpy: with no explicit size, draw one sample per ``df`` entry
# (a scalar ``df`` yields a scalar).
size = u.math.shape(df)
key = self.__get_key(key)
dtype = dtype or environ.dftype()
r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
Expand Down Expand Up @@ -867,14 +869,13 @@ def chisquare(
key = self.__get_key(key)
dtype = dtype or environ.dftype()
if size is None:
if jnp.ndim(df) == 0:
dist = jr.normal(key, (df,), dtype=dtype) ** 2
dist = dist.sum()
else:
raise NotImplementedError('Do not support non-scale "df" when "size" is None')
else:
dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
dist = dist.sum(axis=0)
size = u.math.shape(df)
# A chi-square distribution with ``df`` degrees of freedom is ``2 * Gamma(df/2)``.
# Using the gamma relation (rather than summing ``df`` squared normals) is valid
# for any positive real and array-valued ``df``, matches numpy, and mirrors the
# sibling ``t`` / ``noncentral_chisquare`` implementations.
df = u.math.asarray(df, dtype=dtype)
dist = 2.0 * jr.gamma(key, 0.5 * df, shape=_size2shape(size), dtype=dtype)
return dist

@named_scope(
Expand Down Expand Up @@ -911,10 +912,14 @@ def geometric(
if size is None:
size = u.math.shape(p)
key = self.__get_key(key)
dtype = dtype or environ.dftype()
u_ = jr.uniform(key, size, dtype)
r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
return r
float_dtype = environ.dftype()
u_ = jr.uniform(key, _size2shape(size), float_dtype)
# Inverse-CDF sampling. ``floor(log1p(-u) / log1p(-p))`` is supported on
# {0, 1, ...} (number of failures); the geometric distribution counts the
# number of trials until the first success and is supported on {1, 2, ...}
# with ``P(k == 1) == p``, so add one. Samples are integer-valued.
r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p)) + 1
return u.math.asarray(r, dtype=dtype or environ.ditype())

def _check_p2(self, p):
raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
Expand Down Expand Up @@ -1025,18 +1030,59 @@ def rayleigh(

@named_scope(
'brainstate/random',
static_argnums=(0, 1),
static_argnames=['size']
static_argnums=(0, 4, 6),
static_argnames=['dtype', 'size']
)
def triangular(
self,
left=0.0,
mode=0.5,
right=1.0,
size: Optional[Size] = None,
key: Optional[SeedOrKey] = None
key: Optional[SeedOrKey] = None,
dtype: DTypeLike = None
):
dtype = dtype or environ.dftype()
# ``left``/``mode``/``right`` share a single physical unit, inferred from
# whichever bound carries one (a plain bound is then interpreted in that shared
# unit). A compatible-but-different unit is converted; an incompatible one raises.
values = [
u.math.asarray(_check_py_seq(v), dtype=dtype)
for v in (left, mode, right)
]
unit = u.UNITLESS
for v in values:
q = u.Quantity(v)
if not q.is_unitless:
unit = q.unit
break

def _to_shared_unit(v):
q = u.Quantity(v)
return q.mantissa if q.is_unitless else q.in_unit(unit).mantissa

left, mode, right = (_to_shared_unit(v) for v in values)

if size is None:
size = lax.broadcast_shapes(
u.math.shape(left),
u.math.shape(mode),
u.math.shape(right),
)
size = _size2shape(size)
key = self.__get_key(key)
bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
r = 2 * bernoulli_samples - 1
return r
samples = jr.uniform(key, size, dtype=dtype)

# Inverse-CDF (quantile) transform of the triangular distribution. With
# ``fc = (mode - left) / (right - left)`` the cut point of the unit uniform,
# draws below ``fc`` map onto the rising edge and the rest onto the falling
# edge. ``where(span > 0, ...)`` guards the degenerate ``left == right`` case.
span = right - left
fc = jnp.where(span > 0, (mode - left) / jnp.where(span > 0, span, 1.0), 0.0)
rising = left + jnp.sqrt(samples * span * (mode - left))
falling = right - jnp.sqrt((1.0 - samples) * span * (right - mode))
r = jnp.where(samples < fc, rising, falling)
return u.maybe_decimal(r * unit)

@named_scope(
'brainstate/random',
Expand Down Expand Up @@ -1118,7 +1164,10 @@ def weibull_min(
random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
if scale_m is not None:
r = r / scale_m
# ``scale`` is the distribution scale parameter lambda: a 2-parameter
# Weibull draw is ``lambda * (-ln(1-U))**(1/a)`` (matching scipy
# ``weibull_min`` and the ``weibull`` docstring), so multiply by scale.
r = r * scale_m
return u.maybe_decimal(r * unit)

@named_scope(
Expand Down Expand Up @@ -1527,7 +1576,9 @@ def randint_like(
key: Optional[SeedOrKey] = None
):
if high is None:
high = max(input)
# ``max`` (the Python builtin) raises on multi-dimensional arrays; use an
# array reduction so any-rank templates work.
high = u.math.max(input)
return self.randint(low, high, size=u.math.shape(input), dtype=dtype, key=key)


Expand Down
20 changes: 11 additions & 9 deletions brainstate/random/_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,10 +861,11 @@ def test_chisquare_scalar_and_sized(self):
self.assertEqual(self.rs.chisquare(3).shape, ())
self.assertEqual(self.rs.chisquare(3, size=(4,)).shape, (4,))

def test_chisquare_nonscalar_df_requires_size(self):
"""chisquare with non-scalar df and no size is unsupported."""
with self.assertRaises(NotImplementedError):
self.rs.chisquare(jnp.array([2, 3]))
def test_chisquare_nonscalar_df_infers_shape(self):
"""chisquare with non-scalar df and no size infers the shape from df."""
arr = self.rs.chisquare(jnp.array([2.0, 3.0]))
self.assertEqual(arr.shape, (2,))
self.assertTrue((arr >= 0).all())

def test_dirichlet(self):
"""dirichlet rows sum to one over the simplex axis."""
Expand All @@ -873,10 +874,11 @@ def test_dirichlet(self):
np.testing.assert_allclose(np.asarray(arr).sum(axis=-1), 1.0, atol=1e-5)

def test_geometric(self):
"""geometric yields non-negative integer-valued samples."""
"""geometric yields integer-valued samples supported on {1, 2, ...}."""
arr = self.rs.geometric(0.5, size=(3,))
self.assertEqual(arr.shape, (3,))
self.assertTrue((arr >= 0).all())
self.assertTrue((arr >= 1).all())
self.assertTrue(np.issubdtype(np.asarray(arr).dtype, np.integer))

def test_multinomial(self):
"""multinomial counts sum to n across the category axis."""
Expand Down Expand Up @@ -932,10 +934,10 @@ def test_rayleigh(self):
self.assertTrue((arr >= 0).all())

def test_triangular(self):
"""triangular returns values in {-1, 1}."""
arr = self.rs.triangular(size=(50,))
"""triangular draws lie within [left, right] with the requested shape."""
arr = self.rs.triangular(-1.0, 0.0, 2.0, size=(50,))
self.assertEqual(arr.shape, (50,))
self.assertTrue(jnp.all((arr == -1) | (arr == 1)))
self.assertTrue(jnp.all((arr >= -1.0) & (arr <= 2.0)))

def test_vonmises(self):
"""vonmises returns angles within (-pi, pi]."""
Expand Down