From f25e7097d1bce1ddc4f29ba656de834badc4ea36 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 11 Jun 2026 23:13:37 +0800 Subject: [PATCH] fix(nn): resolve bugs and edge cases found across the nn module audit Systematic audit of brainstate.nn. Each fix is covered by a regression test (several previously-skipped "known bug" tests are now un-skipped and pass). Dropout/elementwise/activations: - AlphaDropout/FeatureAlphaDropout: correct self-normalizing affine constants. - Dropout2d/3d unbatched minimal-dim detection (per-element mask independence). - Softmin/Softmax/LogSoftmax default dim -> last axis; rrelu unit/int handling; soft_shrink zero-branch unit; channel-last docstrings. Linear/init/utils: - ScaledWSLinear mask/weight/bias shapes; AllToAll out>in padding. - TruncatedNormal default bounds; clip_grad_norm unitless-gradient note. Metrics/exp_euler: - Precision/Recall 'weighted' average + average validation; Welford int counter. - exp_euler diagonal-Jacobian docstring clarification. Transforms/param/hidata: - Softplus/NegSoftplus/Negative/Ordered: stable saturation-free forward and unit-safe stable inverse; Sigmoid/Affine log_abs_det_jacobian unit handling and per-batch shape; Affine zero-scale check on mantissa; Exp/Log/Positive saturation docs; HiData.clone/add/pop/replace preserve name. Module/common/collective_ops: - assign_state_values: pytree/Quantity values via tree.map; accept dotted-string and tuple keys. _filter_states dict branch iterates items(). vmap_call_all_fns rebuilt on vmap_new_states (fixes BatchTracer leak). Map.update no longer forwards spmd_axis_name to pmap2. Sequential empty-slice returns empty Sequential. in/out_size setters accept numpy scalars and 0-d arrays uniformly. Delay/dynamics/event_fixedprob: - Delay.max_time grows monotonically across registrations; take_aware_unit retrieval no longer crashes / double-applies unit; update_every made functional via a monotonic per-call write pointer and correct slot-spacing on time retrieval. FixedNumConn afferent_ratio mask respects seed; broken efferent_target='pre' path guarded with a clear NotImplementedError. --- brainstate/nn/_activations.py | 10 +- brainstate/nn/_activations_test.py | 21 +++ brainstate/nn/_collective_ops.py | 75 +++++---- brainstate/nn/_collective_ops_test.py | 59 ++++++- brainstate/nn/_common.py | 26 +-- brainstate/nn/_common_test.py | 41 ++--- brainstate/nn/_conv.py | 150 +++++++++++++---- brainstate/nn/_conv_test.py | 177 ++++++++++++++++++++ brainstate/nn/_delay.py | 81 ++++++--- brainstate/nn/_delay_test.py | 77 ++++++++- brainstate/nn/_dropout.py | 65 ++++++-- brainstate/nn/_dropout_test.py | 54 ++++++ brainstate/nn/_elementwise.py | 19 ++- brainstate/nn/_elementwise_test.py | 25 +++ brainstate/nn/_event_fixedprob.py | 26 ++- brainstate/nn/_event_fixedprob_test.py | 64 +++++--- brainstate/nn/_event_linear.py | 21 ++- brainstate/nn/_event_linear_test.py | 38 +++++ brainstate/nn/_exp_euler.py | 19 ++- brainstate/nn/_hidata.py | 8 +- brainstate/nn/_hidata_test.py | 31 ++++ brainstate/nn/_linear.py | 16 +- brainstate/nn/_linear_test.py | 23 +++ brainstate/nn/_metrics.py | 37 ++++- brainstate/nn/_metrics_test.py | 38 +++++ brainstate/nn/_module.py | 67 ++++++-- brainstate/nn/_module_test.py | 65 +++++++- brainstate/nn/_param.py | 16 +- brainstate/nn/_poolings.py | 219 +++++++++++++++---------- brainstate/nn/_poolings_test.py | 112 ++++++++++++- brainstate/nn/_regularization.py | 30 +++- brainstate/nn/_regularization_test.py | 41 +++++ brainstate/nn/_transform.py | 114 ++++++++++--- brainstate/nn/_transform_test.py | 71 ++++++++ brainstate/nn/_utils.py | 9 + brainstate/nn/init.py | 9 +- brainstate/nn/init_test.py | 8 +- 37 files changed, 1608 insertions(+), 354 deletions(-) diff --git a/brainstate/nn/_activations.py b/brainstate/nn/_activations.py index c64134c2..88202534 100644 --- a/brainstate/nn/_activations.py +++ b/brainstate/nn/_activations.py @@ -210,7 +210,7 @@ def soft_shrink(x: ArrayLike, lambd: float = 0.5) -> Union[jax.Array, u.Quantity u.math.where( x < -lambd, x + lambd, - u.Quantity(0., unit=u.get_unit(lambd)) + u.Quantity(0., unit=u.get_unit(x)) ) ) @@ -279,8 +279,12 @@ def rrelu(x: ArrayLike, lower: float = 0.125, upper: float = 0.3333333333333333) .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations in Convolutional Network." arXiv:1505.00853 """ - a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype) - return u.math.where(u.get_mantissa(x) >= 0., x, a * x) + # Derive a float dtype robustly: ``x`` may be a python scalar or an integer + # array (neither of which has a usable float ``.dtype`` for sampling). + mantissa = u.get_mantissa(x) + sample_dtype = jax.numpy.result_type(mantissa, float) + a = random.uniform(lower, upper, size=u.math.shape(x), dtype=sample_dtype) + return u.math.where(mantissa >= 0., x, a * x) def hard_shrink(x: ArrayLike, lambd: float = 0.5) -> Union[jax.Array, u.Quantity]: diff --git a/brainstate/nn/_activations_test.py b/brainstate/nn/_activations_test.py index f6afabac..26778c1a 100644 --- a/brainstate/nn/_activations_test.py +++ b/brainstate/nn/_activations_test.py @@ -350,5 +350,26 @@ def f(hx, _): fwd() # doesn't crash +class TestActivationAuditRegressions(parameterized.TestCase): + """Regression tests for bugs found in the nn-module audit.""" + + def test_rrelu_accepts_python_scalar(self): + """A2: rrelu must not crash on a python float (which has no .dtype).""" + with brainstate.random.seed_context(0): + out = brainstate.nn.rrelu(1.0) + self.assertTrue(np.isfinite(np.asarray(out)).all()) + out_neg = brainstate.nn.rrelu(-1.0) + self.assertTrue(np.isfinite(np.asarray(out_neg)).all()) + + def test_rrelu_accepts_integer_array(self): + """A2: rrelu must not crash on an integer-typed array.""" + with brainstate.random.seed_context(0): + x = jnp.array([-2, -1, 0, 1, 2], dtype=jnp.int32) + out = np.asarray(brainstate.nn.rrelu(x)) + self.assertEqual(out.shape, (5,)) + # Non-negative entries pass through unchanged. + np.testing.assert_allclose(out[2:], np.array([0., 1., 2.]), rtol=1e-6) + + if __name__ == '__main__': absltest.main() diff --git a/brainstate/nn/_collective_ops.py b/brainstate/nn/_collective_ops.py index 156c65b8..f8a46d64 100644 --- a/brainstate/nn/_collective_ops.py +++ b/brainstate/nn/_collective_ops.py @@ -20,10 +20,9 @@ import jax -from brainstate._state import catch_new_states from brainstate._utils import set_module_as from brainstate.graph import nodes -from brainstate.transform import vmap, vmap_new_states +from brainstate.transform import vmap_new_states from brainstate.typing import Filter from ._module import Module @@ -274,24 +273,25 @@ def vmap_call_all_fns( if not isinstance(kwargs, Mapping): raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.') - @vmap(axis_size=axis_size) - def vmapped_fn(): - with catch_new_states(state_tag) as inner_catcher: - call_all_fns( - target, - fn_name=fn_name, - args=args, - kwargs=kwargs, - node_to_exclude=node_to_exclude, - fn_if_not_exist=fn_if_not_exist - ) - return inner_catcher.get_state_values() - - with catch_new_states(state_tag) as outer_catcher: - values = vmapped_fn() - states = outer_catcher.get_states() - for state, value in zip(states, values): - state.value = value + # Delegate to ``vmap_new_states``, the same transform used by + # ``vmap_init_all_states``. The previous hand-rolled implementation paired an + # inner ``catch_new_states`` (returning per-lane values) with an outer one + # (returning the State objects) and wrote the values back manually. Because the + # States were created *inside* the vmap trace, the write-back could commit a + # ``BatchTracer`` into ``state.value`` and the batched leading axis was never + # materialized, raising ``UnexpectedTracerError`` on later use. ``vmap_new_states`` + # handles new-state batching correctly in a single, well-tested pass. + def call_fn(): + call_all_fns( + target, + fn_name=fn_name, + args=args, + kwargs=kwargs, + node_to_exclude=node_to_exclude, + fn_if_not_exist=fn_if_not_exist, + ) + + vmap_new_states(call_fn, state_tag=state_tag, axis_size=axis_size)() return target @@ -631,14 +631,33 @@ def assign_state_values( # Get current module states variables = target.states() - keys1 = set(all_states.keys()) - keys2 = set(variables.keys()) - # Update matching states - for key in keys2.intersection(keys1): - variables[key].value = jax.numpy.asarray(all_states[key]) + # Normalize keys so both tuple paths (``('layer1', 'weight')`` — the form + # ``states()`` returns) and the documented dotted-string paths + # (``'layer1.weight'``) compare equal (M2). We keep the *original* key + # objects for the returned ``unexpected``/``missing`` lists so callers that + # rely on tuple paths keep working. + def _norm_key(k): + return '.'.join(str(p) for p in k) if isinstance(k, tuple) else str(k) + + var_by_norm = {_norm_key(k): k for k in variables.keys()} + incoming_by_norm = {_norm_key(k): k for k in all_states.keys()} + + # Update matching states. ``jax.tree.map`` preserves pytree-structured values + # (e.g. dict-valued states) and unit-carrying ``Quantity`` values, both of + # which ``jax.numpy.asarray`` would reject (M1). + for norm in set(var_by_norm).intersection(incoming_by_norm): + variables[var_by_norm[norm]].value = jax.tree.map( + jax.numpy.asarray, all_states[incoming_by_norm[norm]] + ) - # Return mismatched keys - unexpected_keys = sorted(keys1 - keys2) - missing_keys = sorted(keys2 - keys1) + # Return mismatched keys in their original (caller-supplied / module) form. + unexpected_keys = sorted( + (incoming_by_norm[n] for n in set(incoming_by_norm) - set(var_by_norm)), + key=_norm_key, + ) + missing_keys = sorted( + (var_by_norm[n] for n in set(var_by_norm) - set(incoming_by_norm)), + key=_norm_key, + ) return unexpected_keys, missing_keys diff --git a/brainstate/nn/_collective_ops_test.py b/brainstate/nn/_collective_ops_test.py index edf7513d..3eb75e42 100644 --- a/brainstate/nn/_collective_ops_test.py +++ b/brainstate/nn/_collective_ops_test.py @@ -827,10 +827,6 @@ def test_call_completes_and_runs_body(self): self.assertIs(returned, module) self.assertTrue(hasattr(module, 'w')) - @pytest.mark.skip(reason="BUG: vmap_call_all_fns leaks a JAX BatchTracer into " - "newly created state values; the batched value is not " - "committed (shape stays per-lane and later use raises " - "UnexpectedTracerError). See report.") def test_batched_init_creates_leading_axis(self): """Vmapped init_state should create a committed leading batch axis.""" module = EnsembleModule() @@ -840,8 +836,6 @@ def test_batched_init_creates_leading_axis(self): ) self.assertEqual(module.w.value.shape, (_testing.SMALL_BATCH, 3)) - @pytest.mark.skip(reason="BUG: vmap_call_all_fns leaks a JAX BatchTracer into " - "newly created state values. See report.") def test_batched_init_distinct_per_lane(self): """Each batch lane should receive an independent random initialization.""" module = EnsembleModule() @@ -851,8 +845,6 @@ def test_batched_init_distinct_per_lane(self): ) self.assertFalse(bool(jnp.allclose(module.w.value[0], module.w.value[1]))) - @pytest.mark.skip(reason="BUG: vmap_call_all_fns leaks a JAX BatchTracer into " - "newly created state values. See report.") def test_positional_single_arg_wrapped(self): """A single non-tuple positional argument is wrapped in a tuple.""" @@ -1011,3 +1003,54 @@ def test_multiple_dicts_merge_last_wins(self): {target_key: jnp.ones(3) * 7.0}, ) _testing.assert_allclose(net.states()[target_key].value, jnp.ones(3) * 7.0) + + # --- Audit regressions (M1: pytree/unit values; M2: dotted-string keys) --- + + def test_assign_dict_valued_state(self): + """M1: a state whose value is a dict must round-trip without crashing.""" + import brainunit as u + + class DictState(brainstate.nn.Module): + def init_state(self): + self.d = brainstate.ParamState({'a': jnp.ones(2), 'b': jnp.zeros(3)}) + + net = DictState() + brainstate.nn.init_all_states(net) + key = ('d',) + unexpected, missing = brainstate.nn.assign_state_values( + net, {key: {'a': jnp.ones(2) * 5.0, 'b': jnp.ones(3) * 2.0}} + ) + self.assertEqual(unexpected, []) + self.assertEqual(missing, []) + _testing.assert_allclose(net.states()[key].value['a'], jnp.ones(2) * 5.0) + _testing.assert_allclose(net.states()[key].value['b'], jnp.ones(3) * 2.0) + + def test_assign_quantity_valued_state(self): + """M1: a state whose value carries physical units must keep its unit.""" + import brainunit as u + + class QtyState(brainstate.nn.Module): + def init_state(self): + self.v = brainstate.ParamState(jnp.ones(3) * u.mV) + + net = QtyState() + brainstate.nn.init_all_states(net) + key = ('v',) + unexpected, missing = brainstate.nn.assign_state_values( + net, {key: jnp.ones(3) * 7.0 * u.mV} + ) + self.assertEqual((unexpected, missing), ([], [])) + restored = net.states()[key].value + self.assertEqual(u.get_unit(restored), u.mV) + _testing.assert_allclose(u.get_mantissa(restored), jnp.ones(3) * 7.0) + + def test_assign_accepts_dotted_string_keys(self): + """M2: dotted-string keys (as documented) must match tuple state paths.""" + net = self._make_net() + unexpected, missing = brainstate.nn.assign_state_values( + net, {'w': jnp.ones(3) * 3.0, 'b': jnp.ones(2) * 4.0} + ) + self.assertEqual(unexpected, []) + self.assertEqual(missing, []) + _testing.assert_allclose(net.states()[('w',)].value, jnp.ones(3) * 3.0) + _testing.assert_allclose(net.states()[('b',)].value, jnp.ones(2) * 4.0) diff --git a/brainstate/nn/_common.py b/brainstate/nn/_common.py index 58448c6e..8e435c4b 100644 --- a/brainstate/nn/_common.py +++ b/brainstate/nn/_common.py @@ -135,7 +135,10 @@ def _filter_states( filtered_states = None elif isinstance(filters, dict): in_states_filter = defaultdict(list) - for filter_, axis in filters: + # ``filters`` maps ``{filter: axis}`` (see Parameters). Iterate ``.items()`` + # so the filter and its axis are unpacked correctly; iterating the dict + # directly yields only keys and breaks the documented form. + for filter_, axis in filters.items(): assert isinstance(axis, int), 'The value of in_states must be the map axis, which should be an integer.' in_states_filter[axis].append(filter_) filtered_states = module.states(*in_states_filter.values()) @@ -564,24 +567,27 @@ def update(self, *args, **kwargs): raise ValueError( 'Map.update called before init_all_states. Please call init_all_states first.' ) + map_kwargs = dict( + in_axes=self.in_axes, + out_axes=self.out_axes, + axis_name=self.axis_name, + state_in_axes=self._call_state_axes, + state_out_axes=self._call_state_axes, + ) if self.behavior == 'vmap': map_fn = vmap2 + # ``spmd_axis_name`` is a vmap-only concept (nested vmap-over-pmap SPMD). + map_kwargs['spmd_axis_name'] = self.spmd_axis_name elif self.behavior == 'pmap': map_fn = pmap2 + # ``pmap2`` has no ``spmd_axis_name`` parameter; forwarding it raises a + # TypeError, so it is intentionally omitted here. else: raise ValueError( 'Invalid behavior specified. Must be "vmap" or "pmap".' ) - return map_fn( - self.module, - in_axes=self.in_axes, - out_axes=self.out_axes, - axis_name=self.axis_name, - spmd_axis_name=self.spmd_axis_name, - state_in_axes=self._call_state_axes, - state_out_axes=self._call_state_axes, - )(*args, **kwargs) + return map_fn(self.module, **map_kwargs)(*args, **kwargs) def map( self, diff --git a/brainstate/nn/_common_test.py b/brainstate/nn/_common_test.py index 1d26bb63..6b9c6b81 100644 --- a/brainstate/nn/_common_test.py +++ b/brainstate/nn/_common_test.py @@ -179,8 +179,8 @@ def test_update_with_explicit_context(self): class TestFilterStatesDict(unittest.TestCase): """Exercise the dictionary branch of ``_filter_states``.""" - def test_filter_states_dict_tuple_key(self): - """Map a (filter, axis) tuple key to its selected states by axis.""" + def test_filter_states_dict_multiple_axes(self): + """M3: documented ``{filter: axis}`` mapping selects states per axis.""" class M(Module): """Module carrying one ParamState and one ShortTermState.""" @@ -192,17 +192,18 @@ def __init__(self): self.s = brainstate.ShortTermState(jnp.zeros(3)) module = M() - # The implementation iterates over the dict's keys and unpacks each into - # (filter_, axis); supply a 2-tuple key so the unpack succeeds. - result = _filter_states(module, {(OfType(brainstate.ParamState), 0): 'ignored'}) + # Two filters mapping to two distinct axes (the documented form). + result = _filter_states( + module, + {OfType(brainstate.ParamState): 0, OfType(brainstate.ShortTermState): 1}, + ) self.assertIn(0, result) + self.assertIn(1, result) self.assertEqual(len(result[0]), 1) + self.assertEqual(len(result[1]), 1) - @pytest.mark.skip(reason="BUG: _filter_states dict branch unpacks dict keys " - "instead of items; documented {filter: axis} form " - "raises TypeError. See report.") def test_filter_states_dict_documented_form(self): - """Documented ``{filter: axis}`` mapping should select states by axis.""" + """M3: documented ``{filter: axis}`` mapping should select states by axis.""" class M(Module): """Module carrying one ParamState.""" @@ -215,6 +216,7 @@ def __init__(self): module = M() result = _filter_states(module, {OfType(brainstate.ParamState): 0}) self.assertIn(0, result) + self.assertEqual(len(result[0]), 1) class TestToPredicate(unittest.TestCase): @@ -439,26 +441,25 @@ def test_pmap_init_and_map_call(self): raise self.assertEqual(out.shape, (1, 2)) - def test_pmap_update_is_broken(self): - """pmap-mode ``update`` raises TypeError from an invalid pmap2 kwarg. + def test_pmap_update_runs(self): + """M5: pmap-mode ``update`` runs (``spmd_axis_name`` is not forwarded to pmap2). - ``Map.update`` always forwards ``spmd_axis_name`` to the map function, - but ``pmap2`` does not accept that keyword. This documents the bug while - keeping the suite green; see report. + Previously ``Map.update`` always forwarded ``spmd_axis_name`` to the map + function, but ``pmap2`` does not accept that keyword, so pmap-mode + ``update`` raised ``TypeError``. It must now run and produce the batched + output shape. """ m = Map(_ParamModule(), init_map_size=1, behavior='pmap') try: - # ``init_all_states`` already drives the pmap path; jax < 0.10 rejects - # ordered effects there, before the TypeError this test documents. + # jax < 0.10 rejects ordered effects inside pmap (raised as early as + # init_all_states), so the whole pmap interaction is guarded. m.init_all_states(din=3, dout=2) - m.update(jnp.ones((1, 3))) - except TypeError: - return # expected: pmap2 rejects the forwarded spmd_axis_name kwarg + out = m.update(jnp.ones((1, 3))) except ValueError as e: if 'Ordered effects' in str(e): self.skipTest(f'pmap on this JAX version rejects ordered effects: {e}') raise - self.fail('expected a TypeError from the invalid pmap2 kwarg, but no error was raised') + self.assertEqual(out.shape, (1, 2)) class TestMapCaller(unittest.TestCase): diff --git a/brainstate/nn/_conv.py b/brainstate/nn/_conv.py index bcd759b1..3be883a7 100644 --- a/brainstate/nn/_conv.py +++ b/brainstate/nn/_conv.py @@ -21,6 +21,7 @@ import brainunit as u import jax import jax.numpy as jnp +import numpy as np from brainstate._state import ParamState from brainstate.typing import ArrayLike @@ -150,6 +151,92 @@ def replicate( f"sequence of length {num_replicate}.") +def _bias_shape( + num_spatial_dims: int, + out_channels: int, + channel_first: bool, +) -> Tuple[int, ...]: + """Build a bias shape that broadcasts against the convolution output. + + The bias must broadcast over the channel axis of the output while being + a singleton along the batch and spatial axes. The channel axis position + depends on the data format. + + Parameters + ---------- + num_spatial_dims : int + The number of spatial dimensions of the convolution. + out_channels : int + The number of output channels. + channel_first : bool + + - If True, the output is ``[B, C, *spatial]`` and the bias broadcasts + over the channel axis at position 1: ``(1, C) + (1,) * num_spatial_dims``. + - If False, the output is ``[B, *spatial, C]`` and the bias broadcasts + over the trailing channel axis: ``(1,) * num_spatial_dims + (C,)``. + + Returns + ------- + tuple of int + The bias shape, excluding the batch dimension (which broadcasts). + + Examples + -------- + .. code-block:: python + + >>> _bias_shape(2, 4, channel_first=False) + (1, 1, 4) + >>> _bias_shape(2, 4, channel_first=True) + (1, 4, 1, 1) + """ + if channel_first: + return (1, out_channels) + (1,) * num_spatial_dims + return (1,) * num_spatial_dims + (out_channels,) + + +def _conv_transpose_padding(k: int, s: int, padding: str) -> Tuple[int, int]: + """Compute the (before, after) padding for one dimension of a transposed conv. + + This mirrors ``jax.lax.conv_transpose``'s internal padding computation so that a + transposed convolution expressed via ``conv_general_dilated`` (with + ``lhs_dilation`` equal to the stride) produces exactly the same output size as + ``jax.lax.conv_transpose`` for both 'SAME' and 'VALID' modes at any stride. + + Parameters + ---------- + k : int + The effective kernel size along the dimension (already accounting for any + right-hand-side dilation). + s : int + The stride (i.e. the input dilation) along the dimension. + padding : {'SAME', 'VALID'} + The padding mode of the corresponding forward convolution. + + Returns + ------- + tuple of int + The ``(pad_before, pad_after)`` padding to feed to ``conv_general_dilated``. + + Notes + ----- + For 'VALID' padding the output length is ``(in - 1) * s + k``; for 'SAME' it is + ``in * s``. The asymmetric split matches Keras/TensorFlow ``Conv?DTranspose``. + """ + if padding == 'SAME': + pad_len = k + s - 2 + if s > k - 1: + pad_a = k - 1 + else: + pad_a = int(np.ceil(pad_len / 2)) + elif padding == 'VALID': + pad_len = k + s - 2 + max(k - s, 0) + pad_a = k - 1 + else: + raise ValueError(f"Invalid padding mode: {padding}") + pad_b = pad_len - pad_a + return pad_a, pad_b + + class _BaseConv(Module): # the number of spatial dimensions num_spatial_dims: int @@ -303,7 +390,7 @@ def __init__( weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False) params = dict(weight=weight) if self.b_initializer is not None: - bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) + bias_shape = _bias_shape(self.num_spatial_dims, self.out_channels, self.channel_first) bias = init.param(self.b_initializer, bias_shape, allow_none=True) params['bias'] = bias @@ -808,7 +895,7 @@ def __init__( weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False) params = dict(weight=weight) if self.b_initializer is not None: - bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) + bias_shape = _bias_shape(self.num_spatial_dims, self.out_channels, self.channel_first) bias = init.param(self.b_initializer, bias_shape, allow_none=True) params['bias'] = bias @@ -1430,34 +1517,17 @@ def __init__( ) # the padding parameter - # For transposed convolution, string padding needs to be converted to explicit padding - # when using lhs_dilation (stride) > 1 + # For transposed convolution we run ``conv_general_dilated`` with + # ``lhs_dilation`` equal to the stride. The explicit padding that yields the + # correct transposed-conv output size differs from a forward conv and is + # computed (per dimension) in ``_conv_op`` via ``_conv_transpose_padding`` so + # that 'SAME'/'VALID' match ``jax.lax.conv_transpose`` at *all* strides + # (including stride == 1). if isinstance(padding, str): assert padding in ['SAME', 'VALID'] self.padding_mode = padding - # Compute explicit padding for transposed convolution - if max(self.stride) > 1: - # For transposed conv with stride, compute padding to achieve desired output size - spatial_in_size = self.in_size[:-1] if not self.channel_first else self.in_size[1:] - if padding == 'SAME': - # For SAME padding with transposed conv: output_size = input_size * stride - # Compute required padding to achieve this - explicit_padding = [] - for i, (k, s, in_dim) in enumerate(zip(self.kernel_size, self.stride, spatial_in_size)): - # Desired output size - out_dim = in_dim * s - # Calculate total padding needed - # For transposed conv: out = (in - 1) * stride + kernel - 2 * pad - # Solving for pad: pad = (kernel + (in-1) * stride - out) // 2 - total_pad = max(k + (in_dim - 1) * s - out_dim, 0) - pad_left = total_pad // 2 - pad_right = total_pad - pad_left - explicit_padding.append((pad_left, pad_right)) - padding = tuple(explicit_padding) - else: # 'VALID' - # For VALID padding: no padding - padding = tuple((0, 0) for _ in range(self.num_spatial_dims)) - # If stride is 1, keep string padding + # Keep the mode string; the explicit per-dimension padding is resolved + # lazily in ``_conv_op`` (it depends on the effective, dilated kernel size). elif isinstance(padding, int): self.padding_mode = 'explicit' padding = tuple((padding, padding) for _ in range(self.num_spatial_dims)) @@ -1498,7 +1568,7 @@ def __init__( weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False) params = dict(weight=weight) if self.b_initializer is not None: - bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) + bias_shape = _bias_shape(self.num_spatial_dims, self.out_channels, self.channel_first) bias = init.param(self.b_initializer, bias_shape, allow_none=True) params['bias'] = bias @@ -1515,10 +1585,32 @@ def __init__( y_shape = abstract_y.shape[1:] self.out_size = y_shape + def _resolve_padding(self): + """Resolve the explicit per-dimension padding for the transposed convolution. + + For 'SAME'/'VALID' modes the padding is derived from the (dilated) kernel size + and stride via :func:`_conv_transpose_padding`, matching + ``jax.lax.conv_transpose``. For explicit (integer/tuple) padding the stored + value is used directly. + """ + if self.padding in ('SAME', 'VALID'): + pads = [] + for k, s, d in zip(self.kernel_size, self.stride, self.rhs_dilation): + # effective kernel size after right-hand-side dilation + eff_k = (k - 1) * d + 1 + pads.append(_conv_transpose_padding(eff_k, s, self.padding)) + return tuple(pads) + return self.padding + def _conv_op(self, x, params): w = params['weight'] if self.w_mask is not None: w = w * self.w_mask + # The transposed-convolution kernel layout is + # (spatial..., out_channels, in_channels // groups). To match + # ``jax.lax.conv_transpose`` (which flips the kernel's spatial axes), flip the + # spatial axes here before feeding ``conv_general_dilated``. + w = jnp.flip(w, axis=tuple(range(self.num_spatial_dims))) # For transposed convolution: # - window_strides should be (1,1,...) - no striding in the conv operation # - lhs_dilation should be the stride - this creates the upsampling effect @@ -1527,7 +1619,7 @@ def _conv_op(self, x, params): lhs=x, rhs=w, window_strides=window_strides, - padding=self.padding, + padding=self._resolve_padding(), lhs_dilation=self.stride, # For transpose conv, use stride as lhs_dilation rhs_dilation=self.rhs_dilation, feature_group_count=self.groups, diff --git a/brainstate/nn/_conv_test.py b/brainstate/nn/_conv_test.py index 3ae18349..4057270b 100644 --- a/brainstate/nn/_conv_test.py +++ b/brainstate/nn/_conv_test.py @@ -17,6 +17,7 @@ import unittest +import jax import jax.numpy as jnp import numpy as np import pytest @@ -1101,3 +1102,179 @@ def test_conv_transpose2d_jit_equal(self): conv = brainstate.nn.ConvTranspose2d(in_size=(8, 8, 16), out_channels=8, kernel_size=3) x = brainstate.random.randn(2, 8, 8, 16) _testing.assert_jit_equal(lambda inp: conv(inp), x) + + +class TestChannelFirstBias(unittest.TestCase): + """Regression tests for C1: channel_first=True together with a non-zero bias. + + Previously the bias was always built channels-last as ``(1, ..., 1, C)`` which + failed to broadcast against a ``[B, C, *spatial]`` channels-first output. + """ + + def test_conv2d_channel_first_with_bias(self): + """Conv2d with channel_first=True and a constant bias produces [B, C, H, W].""" + conv = brainstate.nn.Conv2d( + in_size=(3, 8, 8), + out_channels=4, + kernel_size=3, + channel_first=True, + b_init=brainstate.init.Constant(1.0), + ) + x = jnp.ones((2, 3, 8, 8)) + y = conv(x) + self.assertEqual(y.shape, (2, 4, 8, 8)) + self.assertIn('bias', conv.weight.value) + + def test_conv1d_channel_first_with_bias(self): + """Conv1d with channel_first=True and a constant bias produces [B, C, L].""" + conv = brainstate.nn.Conv1d( + in_size=(8, 50), + out_channels=16, + kernel_size=3, + channel_first=True, + b_init=brainstate.init.Constant(1.0), + ) + x = jnp.ones((2, 8, 50)) + y = conv(x) + self.assertEqual(y.shape, (2, 16, 50)) + + def test_conv3d_channel_first_with_bias(self): + """Conv3d with channel_first=True and a constant bias produces [B, C, H, W, D].""" + conv = brainstate.nn.Conv3d( + in_size=(2, 8, 8, 8), + out_channels=4, + kernel_size=3, + channel_first=True, + b_init=brainstate.init.Constant(1.0), + ) + x = jnp.ones((2, 2, 8, 8, 8)) + y = conv(x) + self.assertEqual(y.shape, (2, 4, 8, 8, 8)) + + def test_scaled_ws_conv2d_channel_first_with_bias(self): + """ScaledWSConv2d with channel_first=True and a constant bias produces [B, C, H, W].""" + conv = brainstate.nn.ScaledWSConv2d( + in_size=(3, 8, 8), + out_channels=4, + kernel_size=3, + channel_first=True, + b_init=brainstate.init.Constant(1.0), + ) + x = jnp.ones((2, 3, 8, 8)) + y = conv(x) + self.assertEqual(y.shape, (2, 4, 8, 8)) + self.assertIn('bias', conv.weight.value) + + def test_conv_transpose2d_channel_first_with_bias(self): + """ConvTranspose2d with channel_first=True and a constant bias produces [B, C, H, W].""" + conv = brainstate.nn.ConvTranspose2d( + in_size=(16, 8, 8), + out_channels=8, + kernel_size=3, + stride=2, + channel_first=True, + b_init=brainstate.init.Constant(1.0), + ) + x = jnp.ones((2, 16, 8, 8)) + y = conv(x) + self.assertEqual(y.shape[0], 2) + self.assertEqual(y.shape[1], 8) # channel axis at position 1 + self.assertIn('bias', conv.weight.value) + + def test_conv_transpose1d_channel_first_with_bias(self): + """ConvTranspose1d with channel_first=True and a constant bias produces [B, C, L].""" + conv = brainstate.nn.ConvTranspose1d( + in_size=(16, 28), + out_channels=8, + kernel_size=4, + stride=2, + channel_first=True, + b_init=brainstate.init.Constant(1.0), + ) + x = jnp.ones((2, 16, 28)) + y = conv(x) + self.assertEqual(y.shape[0], 2) + self.assertEqual(y.shape[1], 8) + + +class TestConvTransposeShapeVsJax(unittest.TestCase): + """Regression tests for C2/C3: transposed-conv output sizes must match + ``jax.lax.conv_transpose`` for both 'SAME' and 'VALID' at all strides. + """ + + def _jax_reference(self, x, w, stride, padding): + """Channels-last 1D reference via jax.lax.conv_transpose (WOI kernel).""" + return jax.lax.conv_transpose( + x, w, + strides=(stride,), + padding=padding, + dimension_numbers=('NWC', 'WIO', 'NWC'), + transpose_kernel=True, + ) + + def test_exact_output_shape_matches_jax(self): + """Output shape equals jax.lax.conv_transpose for k in {2,3,4,5}, s in {1,2,3}.""" + in_len = 7 + in_ch = 2 + out_ch = 3 + for k in (2, 3, 4, 5): + for s in (1, 2, 3): + for padding in ('SAME', 'VALID'): + conv = brainstate.nn.ConvTranspose1d( + in_size=(in_len, in_ch), + out_channels=out_ch, + kernel_size=k, + stride=s, + padding=padding, + ) + x = jnp.ones((1, in_len, in_ch)) + y = conv(x) + # Reference kernel layout for jax: (K, C_out, C_in) + w_ref = jnp.ones((k, out_ch, in_ch)) + y_ref = self._jax_reference(x, w_ref, s, padding) + self.assertEqual( + y.shape, y_ref.shape, + f"shape mismatch k={k} s={s} pad={padding}: " + f"got {y.shape}, jax {y_ref.shape}" + ) + + def test_numeric_equivalence_to_jax(self): + """No-bias, single-group output matches jax.lax.conv_transpose numerically.""" + in_len = 9 + in_ch = 2 + out_ch = 4 + for k in (2, 3, 4, 5): + for s in (1, 2, 3): + for padding in ('SAME', 'VALID'): + conv = brainstate.nn.ConvTranspose1d( + in_size=(in_len, in_ch), + out_channels=out_ch, + kernel_size=k, + stride=s, + padding=padding, + ) + x = brainstate.random.randn(1, in_len, in_ch) + y = conv(x) + # brainstate transposed kernel layout is (K, C_out, C_in) + w = conv.weight.value['weight'] + y_ref = self._jax_reference(x, w, s, padding) + np.testing.assert_allclose( + np.asarray(y), np.asarray(y_ref), atol=1e-4, rtol=1e-4, + err_msg=f"value mismatch k={k} s={s} pad={padding}" + ) + + def test_out_size_attribute_matches_jax(self): + """The cached out_size attribute matches jax.lax.conv_transpose.""" + for k in (2, 3, 4, 5): + for s in (1, 2, 3): + for padding in ('SAME', 'VALID'): + conv = brainstate.nn.ConvTranspose1d( + in_size=(7, 2), + out_channels=3, + kernel_size=k, + stride=s, + padding=padding, + ) + w_ref = jnp.ones((k, 3, 2)) + y_ref = self._jax_reference(jnp.ones((1, 7, 2)), w_ref, s, padding) + self.assertEqual(conv.out_size, y_ref.shape[1:]) diff --git a/brainstate/nn/_delay.py b/brainstate/nn/_delay.py index 6b548fc1..1662b2d7 100644 --- a/brainstate/nn/_delay.py +++ b/brainstate/nn/_delay.py @@ -670,7 +670,10 @@ def register_delay(self, *time_and_idx): environ.get_dt() if self.update_every is None else self.update_every ) max_delay_step = jnp.max(delay_step) - self.max_time = u.math.max(time) + # Take the maximum against the existing ``max_time`` so registering a + # shorter delay after a longer one does not shrink the buffer window. + # This mirrors ``max_length`` below, which already only grows. + self.max_time = u.math.maximum(self.max_time, u.math.max(time)) # delay variable if self.max_length <= max_delay_step + 1: @@ -778,9 +781,13 @@ def _check_delay(delay_len): # unified ring buffer method using write_ptr with jax.ensure_compile_time_eval(): - # Use write_ptr instead of environ.get(environ.I) - # Note: write_ptr points to the NEXT write position, so current position is write_ptr - 1 - current_ptr = self.write_ptr.value // self.update_every_step - 1 + # Use write_ptr instead of environ.get(environ.I). + # ``write_ptr`` is a monotonic per-call counter; the most recently + # written slot is ``(write_ptr - 1) // update_every_step``. (For the + # default ``update_every_step == 1`` this equals the previous + # ``write_ptr // step - 1`` form, but it stays correct under frequency + # control, where many calls map to the same slot.) + current_ptr = (self.write_ptr.value - 1) // self.update_every_step di = current_ptr - delay_step delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32) delay_idx = jax.lax.stop_gradient(delay_idx) @@ -794,10 +801,19 @@ def _check_delay(delay_len): if self._unit is None: return jax.tree.map(lambda a: a[indices], self.history.value) else: + # ``self._unit`` holds a ``brainunit.Unit`` per leaf (recorded with + # ``is_leaf=is_quantity``). The history buffer may itself be a + # ``Quantity`` (when ``target_info`` carried units) whose pytree node + # type differs from ``Unit`` — zipping the two without ``is_leaf`` + # raises "Custom node type mismatch". Treat each history leaf + # (Quantity or plain array) as a single leaf, strip any unit it + # already carries, then re-apply the recorded unit. This both fixes + # the crash and avoids double-counting the unit (mV * mV). return jax.tree.map( - lambda hist, unit: u.maybe_decimal(hist[indices] * unit), + lambda hist, unit: u.maybe_decimal(u.get_mantissa(hist[indices]) * unit), self.history.value, - self._unit + self._unit, + is_leaf=u.math.is_quantity, ) def retrieve_at_time(self, delay_time, *indices) -> PyTree: @@ -842,7 +858,12 @@ def _check_delay(t_now, t_delay): with jax.ensure_compile_time_eval(): diff = current_time - delay_time - float_time_step = diff / dt + # Buffer slots are spaced ``update_every`` apart (one slot per + # ``update_every_step`` calls), so convert the continuous time offset + # into *slot* units using that spacing. Using ``dt`` here would + # over-count by ``update_every_step`` whenever frequency control is on. + slot_dt = dt if self.update_every is None else self.update_every + float_time_step = diff / slot_dt # Use interpolation methods that call retrieve_at_step for bounds checking if ( @@ -867,7 +888,7 @@ def _check_delay(t_now, t_delay): else: # For other interpolation methods (cubic, hermite, polynomial), use the registry # Calculate the buffer position accounting for ring buffer - current_ptr = self.write_ptr.value // self.update_every_step - 1 + current_ptr = (self.write_ptr.value - 1) // self.update_every_step float_buffer_idx = current_ptr - float_time_step if isinstance(self.interp_method, str): @@ -880,7 +901,12 @@ def _check_delay(t_now, t_delay): return interp_func(self.history.value, indices, float_buffer_idx, self.max_length) def _write_to_buffer(self, value: PyTree) -> None: - """Write a value to the ring buffer at current write_ptr position.""" + """Write a value into the ring buffer slot for the current write pointer. + + The write pointer is *not* advanced here; advancement happens once per + ``update`` call in :meth:`_advance_write_ptr` so the buffer cadence is a + function of the number of calls, not of how many writes have occurred. + """ idx = jnp.asarray(self.write_ptr.value // self.update_every_step, dtype=environ.dutype()) idx = jax.lax.stop_gradient(idx) self.history.value = jax.tree.map( @@ -888,22 +914,18 @@ def _write_to_buffer(self, value: PyTree) -> None: self.history.value, value ) - self.write_ptr.value = (self.write_ptr.value + 1) % self.max_length - - def _frequency_controlled_update(self, current: PyTree) -> None: - """Handle frequency-controlled updates with different strategies.""" - # Update time accumulator - should_update = self.write_ptr.value % self.update_every_step == 0 + def _advance_write_ptr(self) -> None: + """Advance the monotonic per-call write pointer. - def do_nothing(): - pass - - # Hold: Only write when threshold crossed - def write_and_reset(): - self._write_to_buffer(current) - - cond(should_update, write_and_reset, do_nothing) + ``write_ptr`` counts ``update`` calls and wraps at + ``max_length * update_every_step`` so that ``write_ptr // update_every_step`` + cycles through the ``max_length`` buffer slots. For the default + (``update_every is None`` ⇒ ``update_every_step == 1``) this reduces to the + ordinary ``(write_ptr + 1) % max_length`` ring advance. + """ + period = self.max_length * self.update_every_step + self.write_ptr.value = (self.write_ptr.value + 1) % period def update(self, current: PyTree) -> None: """ @@ -921,13 +943,18 @@ def _update_impl(self, current: PyTree) -> None: if self.take_aware_unit and self._unit is None: self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity) - # Check if frequency control is enabled if self.update_every is None: - # Default: update every call + # Default: write on every call. self._write_to_buffer(current) else: - # Frequency-controlled update - self._frequency_controlled_update(current) + # Frequency-controlled: write only when the per-call pointer lands on + # an ``update_every_step`` boundary. ``should_write`` keys off the + # monotonic call pointer (advanced below every call), so the cadence + # no longer stalls after the first write. + should_write = (self.write_ptr.value % self.update_every_step) == 0 + cond(should_write, lambda: self._write_to_buffer(current), lambda: None) + # Advance the per-call pointer exactly once per update, in both modes. + self._advance_write_ptr() class StateWithDelay(Delay): diff --git a/brainstate/nn/_delay_test.py b/brainstate/nn/_delay_test.py index 4bdf8d1b..036c82ba 100644 --- a/brainstate/nn/_delay_test.py +++ b/brainstate/nn/_delay_test.py @@ -42,6 +42,23 @@ def test_delay1(self): with self.assertRaises(KeyError): delay.register_entry('c', 10.) + def test_max_time_is_monotonic_across_registrations(self): + """D2: registering a shorter delay after a longer one keeps the larger max_time.""" + delay = brainstate.nn.Delay(jnp.ones((1,))) + delay.register_entry('big', 5.0) + big_max_time = float(delay.max_time) + delay.register_entry('small', 1.0) + # max_time must remain the maximum (5.0), not be clobbered to 1.0. + self.assertEqual(float(delay.max_time), big_max_time) + self.assertEqual(float(delay.max_time), 5.0) + + def test_max_time_grows_for_larger_later_delay(self): + """D2: a longer delay registered later raises max_time accordingly.""" + delay = brainstate.nn.Delay(jnp.ones((1,))) + delay.register_entry('small', 1.0) + delay.register_entry('big', 7.0) + self.assertEqual(float(delay.max_time), 7.0) + def test_rotation_delay(self): rotation_delay = brainstate.nn.Delay(jnp.ones((1,))) t0 = 0. @@ -739,14 +756,48 @@ def test_update_every_below_dt_raises(self): brainstate.nn.Delay(jnp.zeros((2,)), time=1.0, update_every=0.05) def test_frequency_controlled_hold_update(self): - """With ``update_every`` set, the buffer only advances on threshold crossings.""" + """D1: with ``update_every`` set, the per-call pointer keeps advancing. + + The write pointer counts calls and wraps at ``max_length * update_every_step``; + a buffer write happens every ``update_every_step`` calls. + """ delay = brainstate.nn.Delay(jnp.zeros((2,)), time=1.0, update_every=0.5) delay.init_state() - # update_every_step == 5; only every 5th call writes to the buffer. + period = delay.max_length * delay.update_every_step # 3 * 5 == 15 for i in range(20): delay.update(jnp.ones((2,)) * i) - # write_ptr advances once per threshold crossing (20 / 5 == 4 writes). - self.assertEqual(int(delay.write_ptr.value), 4 % delay.max_length) + # The pointer is a monotonic call counter modulo the period (20 % 15 == 5), + # NOT frozen at 1 (the old chicken-and-egg bug). + self.assertEqual(int(delay.write_ptr.value), 20 % period) + + def test_update_every_buffer_advances_each_window(self): + """D1: the buffer actually records distinct values across windows. + + With ``update_every_step == 5`` the writes at calls 0, 5, 10 land in + successive ring slots, so the buffer is not stuck holding the first value. + """ + delay = brainstate.nn.Delay(jnp.zeros((1,)), time=1.0, update_every=0.5) + delay.init_state() + for i in range(11): + delay.update(jnp.ones((1,)) * float(i)) + # Writes occurred at i = 0, 5, 10 -> the three ring slots hold 0, 5, 10. + slots = set(float(v) for v in jnp.ravel(delay.history.value).tolist()) + self.assertEqual(slots, {0.0, 5.0, 10.0}) + + def test_retrieve_at_time_uses_update_every_spacing(self): + """D5: continuous-time retrieval converts delay to slots via ``update_every``.""" + delay = brainstate.nn.Delay(jnp.zeros((1,)), time=1.0, + update_every=0.5, interpolation='round') + delay.init_state() + for i in range(11): + delay.update(jnp.ones((1,)) * float(i)) + # Most recent buffer slot holds value 10 (written at call 10). A zero delay + # should return it; a one-window (0.5) delay should return the prior slot (5). + with brainstate.environ.context(t=1.0): + latest = delay.retrieve_at_time(1.0) # diff = 0 -> newest slot + one_back = delay.retrieve_at_time(0.5) # diff = 0.5 -> one slot back + self.assertTrue(jnp.allclose(latest, jnp.ones((1,)) * 10.0)) + self.assertTrue(jnp.allclose(one_back, jnp.ones((1,)) * 5.0)) def test_update_frequency_none_updates_every_call(self): """Without ``update_every`` the buffer advances on every update.""" @@ -845,10 +896,8 @@ def test_take_aware_unit_update_records_unit(self): delay.update(jnp.ones((2,)) * i * u.mV) self.assertEqual(delay._unit, u.mV) - @pytest.mark.skip(reason="BUG: take_aware_unit retrieval fails with pytree node mismatch " - "(Quantity history vs Unit _unit) in _retrieve_at_step_impl") def test_take_aware_unit_retrieval(self): - """A unit-aware delay should return values carrying the original unit.""" + """D3: a unit-aware delay returns values carrying the original unit (no crash).""" import brainunit as u delay = brainstate.nn.Delay(jnp.zeros((2,)) * u.mV, time=1.0, take_aware_unit=True) delay.register_entry('e', 0.5) @@ -858,6 +907,20 @@ def test_take_aware_unit_retrieval(self): result = delay.at('e') self.assertEqual(u.get_unit(result), u.mV) + def test_take_aware_unit_retrieval_value_not_double_unit(self): + """D3: the retrieved magnitude is not double-counted (mV, not mV**2).""" + import brainunit as u + delay = brainstate.nn.Delay(jnp.zeros((2,)) * u.mV, time=1.0, take_aware_unit=True) + delay.register_entry('latest', 0.0) + delay.init_state() + for i in range(5): + delay.update(jnp.ones((2,)) * float(i) * u.mV) + result = delay.at('latest') + # Unit is exactly mV (mantissa unchanged), confirming no mV**2 double-apply. + self.assertEqual(u.get_unit(result), u.mV) + # The most recently written value (i = 4) is returned at zero delay. + self.assertTrue(jnp.allclose(u.get_mantissa(result), jnp.ones((2,)) * 4.0)) + class TestDelayMethodBackwardCompat(unittest.TestCase): """Tests for the deprecated delay_method argument handling.""" diff --git a/brainstate/nn/_dropout.py b/brainstate/nn/_dropout.py index 1cd3beff..54481855 100644 --- a/brainstate/nn/_dropout.py +++ b/brainstate/nn/_dropout.py @@ -40,8 +40,8 @@ class Dropout(ElementWiseBlock): """A layer that stochastically ignores a subset of inputs each training step. - In training, to compensate for the fraction of input values dropped (`rate`), - all surviving values are multiplied by `1 / (1 - rate)`. + In training, to compensate for the fraction of input values dropped, all surviving + values are multiplied by ``1 / prob``, where ``prob`` is the *keep* probability. This layer is active only during training (``mode=brainstate.mixin.Training``). In other circumstances it is a no-op. @@ -128,7 +128,7 @@ def __call__(self, x): if inp_dim not in (self.minimal_dim, self.minimal_dim + 1): raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, " f"but received a {inp_dim}D input. {self._get_msg(x)}") - is_not_batched = self.minimal_dim + is_not_batched = (inp_dim == self.minimal_dim) if is_not_batched: channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis) mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)] @@ -186,9 +186,16 @@ class Dropout1d(_DropoutNd): Notes ----- - Input shape: :math:`(N, C, L)` or :math:`(C, L)`. + With the default ``channel_axis=-1`` (channel-last convention used throughout + brainstate), the channel is the last axis. - Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input). + Input shape: :math:`(N, L, C)` or :math:`(L, C)`. + + Output shape: :math:`(N, L, C)` or :math:`(L, C)` (same shape as input). + + A whole channel (a 1D feature map) is dropped together, and the mask is drawn + independently per batch element. Pass ``channel_axis`` to select a different + channel dimension. References ---------- @@ -246,9 +253,16 @@ class Dropout2d(_DropoutNd): Notes ----- - Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`. + With the default ``channel_axis=-1`` (channel-last convention used throughout + brainstate), the channel is the last axis. + + Input shape: :math:`(N, H, W, C)` or :math:`(H, W, C)`. - Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input). + Output shape: :math:`(N, H, W, C)` or :math:`(H, W, C)` (same shape as input). + + A whole channel (a 2D feature map) is dropped together, and the mask is drawn + independently per batch element. Pass ``channel_axis`` to select a different + channel dimension. References ---------- @@ -306,9 +320,16 @@ class Dropout3d(_DropoutNd): Notes ----- - Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + With the default ``channel_axis=-1`` (channel-last convention used throughout + brainstate), the channel is the last axis. + + Input shape: :math:`(N, D, H, W, C)` or :math:`(D, H, W, C)`. + + Output shape: :math:`(N, D, H, W, C)` or :math:`(D, H, W, C)` (same shape as input). - Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + A whole channel (a 3D feature map) is dropped together, and the mask is drawn + independently per batch element. Pass ``channel_axis`` to select a different + channel dimension. References ---------- @@ -401,9 +422,14 @@ def __init__( alpha = -1.7580993408473766 self.alpha = alpha - # Affine transformation parameters to maintain mean and variance - self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5 - self.b = -self.a * alpha * prob + # Affine transformation parameters to maintain mean and variance. + # ``prob`` is the *keep* probability; ``q`` is the *drop* probability. + # The self-normalizing affine is a = (keep * (1 + q * alpha**2))**-0.5, + # b = -a * q * alpha (these only coincide with the keep/drop-swapped form + # at prob == 0.5). + keep, q = prob, 1. - prob + self.a = (keep * (1 + q * alpha ** 2)) ** -0.5 + self.b = -self.a * q * alpha def __call__(self, x): dtype = u.math.get_dtype(x) @@ -454,7 +480,8 @@ class FeatureAlphaDropout(ElementWiseBlock): Notes ----- - Input shape: :math:`(N, C, *)` where C is the channel dimension. + With the default ``channel_axis=-1`` (channel-last convention used throughout + brainstate), the channel is the last axis, e.g. input shape :math:`(N, *, C)`. Output shape: Same shape as input. @@ -495,9 +522,11 @@ def __init__( alpha = -1.7580993408473766 self.alpha = alpha - # Affine transformation parameters to maintain mean and variance - self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5 - self.b = -self.a * alpha * prob + # Affine transformation parameters to maintain mean and variance. + # ``prob`` is the *keep* probability; ``q`` is the *drop* probability. + keep, q = prob, 1. - prob + self.a = (keep * (1 + q * alpha ** 2)) ** -0.5 + self.b = -self.a * q * alpha def __call__(self, x): dtype = u.math.get_dtype(x) @@ -523,8 +552,8 @@ def __call__(self, x): class DropoutFixed(ElementWiseBlock): """A dropout layer with a fixed dropout mask along the time axis. - In training, to compensate for the fraction of input values dropped, - all surviving values are multiplied by `1 / (1 - prob)`. + In training, to compensate for the fraction of input values dropped, all surviving + values are multiplied by ``1 / prob``, where ``prob`` is the *keep* probability. This layer is active only during training (``mode=brainstate.mixin.Training``). In other circumstances it is a no-op. diff --git a/brainstate/nn/_dropout_test.py b/brainstate/nn/_dropout_test.py index 2ed4e4c5..8c075b7e 100644 --- a/brainstate/nn/_dropout_test.py +++ b/brainstate/nn/_dropout_test.py @@ -476,5 +476,59 @@ def test_dropoutfixed_various_probs(self, prob): self.assertEqual(input_data.shape, output_data.shape) +class TestDropoutAuditRegressions(parameterized.TestCase): + """Regression tests for bugs found in the nn-module audit.""" + + @parameterized.parameters(0.2, 0.5, 0.8) + def test_alpha_dropout_affine_constants(self, prob): + """N1: AlphaDropout affine a/b must use keep/drop probs correctly (not only valid at 0.5).""" + alpha = -1.7580993408473766 + keep, q = prob, 1.0 - prob + expected_a = (keep * (1 + q * alpha ** 2)) ** -0.5 + expected_b = -expected_a * q * alpha + m = brainstate.nn.AlphaDropout(prob=prob) + self.assertAlmostEqual(float(m.a), float(expected_a), places=5) + self.assertAlmostEqual(float(m.b), float(expected_b), places=5) + fm = brainstate.nn.FeatureAlphaDropout(prob=prob) + self.assertAlmostEqual(float(fm.a), float(expected_a), places=5) + self.assertAlmostEqual(float(fm.b), float(expected_b), places=5) + + def test_alpha_dropout_self_normalizes_for_nonhalf_prob(self): + """N1: at prob != 0.5 the output should still keep ~zero mean / ~unit variance.""" + with brainstate.random.seed_context(0): + m = brainstate.nn.AlphaDropout(prob=0.8) + x = brainstate.random.randn(200000) + with brainstate.environ.context(fit=True): + out = np.asarray(m(x)) + self.assertLess(abs(out.mean()), 0.05) + self.assertLess(abs(out.std() - 1.0), 0.05) + + def test_dropout2d_mask_is_independent_per_batch_element(self): + """N2: batched Dropout2d must draw an independent channel mask per batch element.""" + with brainstate.random.seed_context(0): + m = brainstate.nn.Dropout2d(prob=0.5) # default channel_axis=-1 (channel-last) + x = brainstate.random.randn(8, 4, 4, 6) # (N, H, W, C) + with brainstate.environ.context(fit=True): + out = np.asarray(m(x)) + # Each channel is dropped as a whole (mask constant over the spatial dims). + dropped = np.all(out == 0, axis=(1, 2)) # (N, C) + kept = np.all(out != 0, axis=(1, 2)) # (N, C) + self.assertTrue(np.all(dropped | kept), "channels must be all-zero or all-kept") + # The drop pattern must NOT be identical across the batch (the bug shared one mask). + self.assertGreater(float(dropped.std(axis=0).sum()), 0.0, + "channel-drop pattern is identical across the batch axis") + + def test_dropout1d_unbatched_still_channelwise(self): + """N2: unbatched Dropout1d still drops whole channels along the last axis.""" + with brainstate.random.seed_context(1): + m = brainstate.nn.Dropout1d(prob=0.5) + x = brainstate.random.randn(10, 6) # (L, C) + with brainstate.environ.context(fit=True): + out = np.asarray(m(x)) + dropped = np.all(out == 0, axis=0) # (C,) + kept = np.all(out != 0, axis=0) + self.assertTrue(np.all(dropped | kept)) + + if __name__ == '__main__': absltest.main() diff --git a/brainstate/nn/_elementwise.py b/brainstate/nn/_elementwise.py index 8cfc8de3..f08ad22e 100644 --- a/brainstate/nn/_elementwise.py +++ b/brainstate/nn/_elementwise.py @@ -158,6 +158,13 @@ class RReLU(ElementWiseBlock): upper : float, optional Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + Notes + ----- + Unlike PyTorch's ``RReLU``, this layer samples a fresh random negative slope on + **every** call (there is no separate evaluation mode that uses the fixed midpoint + slope). Inference is therefore non-deterministic; fix ``lower == upper`` for a + deterministic slope. + Shape ----- - Input: :math:`(*)`, where :math:`*` means any number of dimensions. @@ -956,8 +963,10 @@ class PReLU(ElementWiseBlock): Notes ----- - Weight decay should not be used when learning :math:`a` for good performance. - - Channel dim is the 2nd dim of input. When input has dims < 2, then there is - no channel dim and the number of channels = 1. + - Following brainstate's channel-last convention, the per-channel parameter + :math:`a` broadcasts against the **last** axis of the input. When + ``num_parameters > 1``, the size of the input's last dimension must equal + ``num_parameters``. Examples -------- @@ -1083,7 +1092,7 @@ def __init__(self, dim: Optional[int] = None) -> None: self.dim = dim def __call__(self, x: ArrayLike) -> ArrayLike: - return F.softmin(x, self.dim) + return F.softmin(x, -1 if self.dim is None else self.dim) def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})' @@ -1144,7 +1153,7 @@ def __init__(self, dim: Optional[int] = None) -> None: self.dim = dim def __call__(self, x: ArrayLike) -> ArrayLike: - return F.softmax(x, self.dim) + return F.softmax(x, -1 if self.dim is None else self.dim) def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})' @@ -1227,7 +1236,7 @@ def __init__(self, dim: Optional[int] = None) -> None: self.dim = dim def __call__(self, x: ArrayLike) -> ArrayLike: - return F.log_softmax(x, self.dim) + return F.log_softmax(x, -1 if self.dim is None else self.dim) def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})' diff --git a/brainstate/nn/_elementwise_test.py b/brainstate/nn/_elementwise_test.py index dfadc856..8b1fe1ba 100644 --- a/brainstate/nn/_elementwise_test.py +++ b/brainstate/nn/_elementwise_test.py @@ -826,5 +826,30 @@ def single_forward(x): self.assertEqual(output.shape, x.shape) +class TestElementwiseAuditRegressions(parameterized.TestCase): + """Regression tests for bugs found in the nn-module audit.""" + + @parameterized.parameters(nn.Softmax, nn.Softmin, nn.LogSoftmax) + def test_softmax_family_default_dim_is_last_axis(self, cls): + """A1: default dim must normalize over the last axis, not the whole array.""" + x = jnp.asarray([[1., 2., 3.], [4., 5., 6.]]) + out = np.asarray(cls()(x)) + explicit = np.asarray(cls(dim=-1)(x)) + np.testing.assert_allclose(out, explicit, rtol=1e-6) + if cls is nn.LogSoftmax: + probs = np.exp(out) + else: + probs = out + # Each row sums to 1 (the whole-array bug made the grand total 1 instead). + np.testing.assert_allclose(probs.sum(axis=-1), np.ones(2), rtol=1e-5) + + def test_prelu_broadcasts_on_last_axis(self): + """A3: per-channel PReLU weight broadcasts against the last (channel-last) axis.""" + m = nn.PReLU(num_parameters=3, init=0.25) + x = brainstate.random.randn(2, 4, 3) # channel-last, C=3 + out = m(x) + self.assertEqual(out.shape, x.shape) + + if __name__ == '__main__': absltest.main() \ No newline at end of file diff --git a/brainstate/nn/_event_fixedprob.py b/brainstate/nn/_event_fixedprob.py index c25b609b..c07dfe12 100644 --- a/brainstate/nn/_event_fixedprob.py +++ b/brainstate/nn/_event_fixedprob.py @@ -120,6 +120,24 @@ def __init__( self.out_size = out_size self.efferent_target = efferent_target assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".' + if efferent_target == 'pre': + # The 'pre' path builds indices shaped ``(out_size, conn_num)`` with + # values in ``[0, in_size)`` and hands them to + # ``brainevent.FixedPreNumConn``/``CSC`` with ``shape=(out_size, in_size)``. + # brainevent then validates the index rows against ``in_size`` and the + # counts disagree whenever ``in_size != out_size``: with + # ``afferent_ratio == 1`` this raises "Pre-synaptic row number mismatch", + # and with ``afferent_ratio < 1`` it corrupts the heap (a native + # ``free(): invalid next size`` abort that can take down the process). + # Until the underlying index layout is corrected, reject this + # configuration up front with a clear, catchable error rather than + # risking a hard crash. + raise NotImplementedError( + "efferent_target='pre' is not currently supported by FixedNumConn: " + "the generated connection indices do not match the layout expected by " + "the underlying brainevent sparse connection, which can raise a shape " + "error or abort the process. Use efferent_target='post' instead." + ) assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].' if isinstance(conn_num, float): assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].' @@ -140,9 +158,13 @@ def __init__( n_post = self.in_size[-1] n_pre = self.out_size[-1] + # A single seeded host-side RNG drives both the connection indices and + # the afferent-ratio pre-selection mask, so both are reproducible from + # ``seed`` (previously ``pre_selected`` used the global ``np.random``, + # ignoring ``seed`` entirely). + rng = np.random if seed is None else np.random.RandomState(seed) with jax.ensure_compile_time_eval(): if allow_multi_conn: - rng = np.random if seed is None else np.random.RandomState(seed) indices = rng.randint(0, n_post, size=(n_pre, self.conn_num)) else: indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init) @@ -159,7 +181,7 @@ def __init__( self.conn = csr else: - self.pre_selected = np.random.random(n_pre) < afferent_ratio + self.pre_selected = rng.random(n_pre) < afferent_ratio indices = indices[self.pre_selected].flatten() conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False)) self.weight = param_type(conn_weight) diff --git a/brainstate/nn/_event_fixedprob_test.py b/brainstate/nn/_event_fixedprob_test.py index 29bc3748..b1c07230 100644 --- a/brainstate/nn/_event_fixedprob_test.py +++ b/brainstate/nn/_event_fixedprob_test.py @@ -204,6 +204,22 @@ def test_afferent_ratio_post_csr_branch(self): out = m(brainstate.random.rand(20)) assert out.shape == (40,) + def test_afferent_ratio_pre_selected_respects_seed(self): + """D4: ``pre_selected`` is reproducible from ``seed`` (was global np.random).""" + m1 = FixedNumConn(20, 40, 0.2, 1.0, efferent_target='post', + afferent_ratio=0.5, seed=7) + m2 = FixedNumConn(20, 40, 0.2, 1.0, efferent_target='post', + afferent_ratio=0.5, seed=7) + assert np.array_equal(np.asarray(m1.pre_selected), np.asarray(m2.pre_selected)) + + def test_afferent_ratio_pre_selected_differs_across_seeds(self): + """Different seeds yield different ``pre_selected`` masks.""" + m1 = FixedNumConn(60, 40, 0.2, 1.0, efferent_target='post', + afferent_ratio=0.5, seed=1) + m2 = FixedNumConn(60, 40, 0.2, 1.0, efferent_target='post', + afferent_ratio=0.5, seed=2) + assert not np.array_equal(np.asarray(m1.pre_selected), np.asarray(m2.pre_selected)) + class TestFixedNumConnZeroConnection: """Cover the zero-connection (FakeState) path of ``FixedNumConn``.""" @@ -244,28 +260,28 @@ def test_event_update_matches_dense_reference(self): assert np.allclose(out, ref, rtol=1e-4, atol=1e-4) -class TestFixedNumConnKnownBugs: - """Document genuine construction bugs in the ``efferent_target='pre'`` path.""" - - @pytest.mark.skip(reason="BUG: efferent_target='pre' crashes in brainevent." - " FixedNumConn builds indices of shape (n_pre, conn_num) = " - "(out_size, conn_num) but brainevent.FixedPreNumConn(shape=(n_pre, " - "n_post)) validates indices rows against shape[1]=n_post=in_size, " - "raising 'Pre-synaptic row number mismatch. 40 != 20'.") - def test_efferent_target_pre_constructs(self): - """``efferent_target='pre'`` should construct a valid connection (currently crashes).""" - m = FixedNumConn(20, 40, 0.2, 1.0, efferent_target='pre', seed=1) - out = m(brainstate.random.rand(40)) - assert out.shape == (20,) - - @pytest.mark.skip(reason="BUG: efferent_target='pre' with afferent_ratio<1 triggers a native" - " memory abort ('free(): invalid next size (fast)') inside the " - "brainevent CSC path, because indices are sized for (n_pre, conn_num)" - " = (out_size, conn_num) rather than the n_post=in_size rows the CSC " - "shape expects.") - def test_efferent_target_pre_afferent_ratio_constructs(self): - """'pre' target with sub-unity afferent_ratio should build a CSC connection.""" - m = FixedNumConn(20, 40, 0.2, 1.0, efferent_target='pre', +class TestFixedNumConnEfferentTargetPreGuard: + """D7: the broken ``efferent_target='pre'`` path is rejected up front. + + The 'pre' index layout disagrees with the underlying brainevent connection + shape: with ``afferent_ratio == 1`` brainevent raises a shape mismatch, and + with ``afferent_ratio < 1`` it aborts the process via a native heap + corruption. ``FixedNumConn`` now guards against this by raising a clear, + catchable ``NotImplementedError`` before any unsafe construction happens. + """ + + def test_efferent_target_pre_raises_not_implemented(self): + """``efferent_target='pre'`` raises a clear ``NotImplementedError``.""" + with pytest.raises(NotImplementedError): + FixedNumConn(20, 40, 0.2, 1.0, efferent_target='pre', seed=1) + + def test_efferent_target_pre_afferent_ratio_raises_not_implemented(self): + """'pre' with sub-unity afferent_ratio is rejected before the unsafe path.""" + with pytest.raises(NotImplementedError): + FixedNumConn(20, 40, 0.2, 1.0, efferent_target='pre', afferent_ratio=0.5, seed=1) - out = m(brainstate.random.rand(40)) - assert out.shape == (20,) + + def test_efferent_target_pre_square_also_guarded(self): + """Even square in/out sizes are guarded (the path is uniformly unsupported).""" + with pytest.raises(NotImplementedError): + FixedNumConn(20, 20, 0.2, 1.0, efferent_target='pre', seed=1) diff --git a/brainstate/nn/_event_linear.py b/brainstate/nn/_event_linear.py index 730713b2..920e0093 100644 --- a/brainstate/nn/_event_linear.py +++ b/brainstate/nn/_event_linear.py @@ -75,7 +75,26 @@ def __init__( def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]: weight = self.weight.value if u.math.size(weight) == 1: - return u.math.ones(self.out_size) * (u.math.sum(spk) * weight) + # Homogeneous (scalar) weight: every post-synaptic neuron receives the + # same total input. The reduction over ``spk`` must mirror the dense path + # below so the two stay numerically consistent: + # + # - ``float_as_event=True``: the dense ``brainevent.EventArray`` path treats + # each nonzero entry as a unit event, so the *forward* value reduces by + # the event count (``sum(spk != 0)``). Its custom VJP, however, propagates + # the value-sum gradient w.r.t. ``spk`` (as ``spk @ weight`` would). We + # reproduce both: the forward equals the event count while the gradient + # flows through the value sum (the stop-gradient cancels in the forward + # pass but leaves a unit derivative on ``spk``). + # - ``float_as_event=False``: the dense ``spk @ weight`` path sums spike + # *values*, so reduce by the value sum directly. + if self.float_as_event: + n_events = u.math.sum(spk != 0) + value_sum = u.math.sum(spk) + reduced = n_events + (value_sum - jax.lax.stop_gradient(value_sum)) + else: + reduced = u.math.sum(spk) + return u.math.ones(self.out_size) * (reduced * weight) if self.float_as_event: return brainevent.EventArray(spk) @ weight diff --git a/brainstate/nn/_event_linear_test.py b/brainstate/nn/_event_linear_test.py index c705e208..1787bcb9 100644 --- a/brainstate/nn/_event_linear_test.py +++ b/brainstate/nn/_event_linear_test.py @@ -122,3 +122,41 @@ def f2(x, w): o2, r2 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w))) assert (jnp.allclose(o1, o2)) assert (jnp.allclose(r1, r2)) + + @pytest.mark.parametrize('float_as_event', [True, False]) + def test_homogeneous_matches_dense_nonbinary(self, float_as_event): + """Regression for C4: the scalar-weight (homogeneous) path must agree with the + dense/heterogeneous path for non-binary float input, for both float_as_event modes. + + With ``float_as_event=True`` the dense ``brainevent.EventArray`` path counts each + nonzero entry as a unit event, so the homogeneous path must reduce by event count + (``sum(spk != 0)``) rather than by spike value. With ``float_as_event=False`` both + paths sum spike values. + """ + n_in = 10 + n_out = 8 + weight = 1.5 + spk = jnp.asarray([0., 2., 0., 3., 0., 0., 1., 0., 0., 5.]) + + # Homogeneous (scalar weight) path. + homo = brainstate.nn.EventLinear(n_in, n_out, weight, float_as_event=float_as_event) + y_homo = homo(spk) + + # Equivalent dense / heterogeneous path: a full weight matrix filled with the scalar. + dense = brainstate.nn.EventLinear( + n_in, n_out, + braintools.init.Constant(weight), + float_as_event=float_as_event, + ) + y_dense = dense(spk) + + assert jnp.allclose(y_homo, y_dense), ( + f"float_as_event={float_as_event}: homogeneous {y_homo} != dense {y_dense}" + ) + + # Sanity-check the expected reduction explicitly. + if float_as_event: + expected_scalar = jnp.sum(spk != 0) * weight # event count = 4 nonzeros + else: + expected_scalar = jnp.sum(spk) * weight # value sum = 11 + assert jnp.allclose(y_homo, jnp.ones((n_out,)) * expected_scalar) diff --git a/brainstate/nn/_exp_euler.py b/brainstate/nn/_exp_euler.py index 60cf2797..c4cb12a6 100644 --- a/brainstate/nn/_exp_euler.py +++ b/brainstate/nn/_exp_euler.py @@ -96,11 +96,24 @@ def exp_euler_step( **Algorithm:** - The method computes the Jacobian :math:`J = \frac{\partial f}{\partial x}` and - uses the exponential-related function :math:`\varphi(z) = (e^z - 1)/z` to update: + The method computes the *element-wise* (diagonal) derivative + :math:`J_i = \frac{\partial f_i}{\partial x_i}` via ``vector_grad`` and uses the + exponential-related function :math:`\varphi(z) = (e^z - 1)/z` to update each + component independently: .. math:: - x_{n+1} = x_n + dt \cdot \varphi(dt \cdot J) \cdot f(x_n, t_n) + x_{n+1, i} = x_{n, i} + dt \cdot \varphi(dt \cdot J_i) \cdot f_i(x_n, t_n) + + .. important:: + + Only the diagonal of the Jacobian is used; **off-diagonal coupling between + state components is treated explicitly (plain Euler), not integrated + exponentially**. For strongly coupled systems where the cross terms + :math:`\partial f_i / \partial x_j` (:math:`i \neq j`) dominate, this scheme + behaves like forward Euler on those terms. It is exact only for systems whose + linearization is diagonal (each component decays/grows according to its own + :math:`\partial f_i / \partial x_i`), which is the common case for + per-neuron/per-synapse dynamics in spiking networks. For SDEs, a stochastic term is added: diff --git a/brainstate/nn/_hidata.py b/brainstate/nn/_hidata.py index 009e1a70..62ce2513 100644 --- a/brainstate/nn/_hidata.py +++ b/brainstate/nn/_hidata.py @@ -236,7 +236,7 @@ def clone(self) -> 'HiData': cloned_children[k] = v.clone() else: cloned_children[k] = v - return self.__class__(children=cloned_children) + return self.__class__(children=cloned_children, name=self.name) @property def state_size(self) -> int: @@ -272,13 +272,13 @@ def add(self, *args, **updates) -> 'HiData': children[k] = v for k in updates: children[k] = updates[k] - return HiData(children=children) + return self.__class__(children=children, name=self.name) def pop(self, *args) -> 'HiData': children = {k: v for k, v in self.children.items()} for arg in args: children.pop(arg) - return HiData(children=children) + return self.__class__(children=children, name=self.name) def replace(self, **updates) -> 'HiData': """ @@ -296,7 +296,7 @@ def replace(self, **updates) -> 'HiData': children = {k: v for k, v in self.children.items()} for k in updates: children[k] = updates[k] - return self.__class__(children=children) + return self.__class__(children=children, name=self.name) def to_dict(self) -> Dict: """ diff --git a/brainstate/nn/_hidata_test.py b/brainstate/nn/_hidata_test.py index 3405937e..87119110 100644 --- a/brainstate/nn/_hidata_test.py +++ b/brainstate/nn/_hidata_test.py @@ -542,5 +542,36 @@ def test_dtype_skips_array_less_nested_child(self): self.assertEqual(outer.dtype, jnp.float32) +class TestHiDataNamePreservation(unittest.TestCase): + """Regression tests for audit finding T6 (name dropped by clone/add/pop/replace).""" + + def test_clone_preserves_name(self): + d = HiData(children={'x': jnp.array([1.0])}, name='layer1') + self.assertEqual(d.clone().name, 'layer1') + + def test_add_preserves_name(self): + d = HiData(children={'x': jnp.array([1.0])}, name='layer1') + self.assertEqual(d.add(y=jnp.array([2.0])).name, 'layer1') + + def test_pop_preserves_name(self): + d = HiData(children={'x': jnp.array([1.0]), 'y': jnp.array([2.0])}, name='layer1') + popped = d.pop('y') + self.assertEqual(popped.name, 'layer1') + self.assertNotIn('y', popped) + + def test_replace_preserves_name(self): + d = HiData(children={'x': jnp.array([1.0])}, name='layer1') + self.assertEqual(d.replace(x=jnp.array([5.0])).name, 'layer1') + + def test_clone_preserves_subclass_and_name(self): + class MyState(HiData): + pass + + d = MyState(children={'x': jnp.array([1.0])}, name='custom') + cloned = d.clone() + self.assertIsInstance(cloned, MyState) + self.assertEqual(cloned.name, 'custom') + + if __name__ == '__main__': unittest.main() diff --git a/brainstate/nn/_linear.py b/brainstate/nn/_linear.py index c8cb5d05..5c187117 100644 --- a/brainstate/nn/_linear.py +++ b/brainstate/nn/_linear.py @@ -313,15 +313,15 @@ def __init__( 'and "out_size" must be the same.') # w_mask - self.w_mask = init.param(w_mask, (self.in_size[0], 1)) + self.w_mask = init.param(w_mask, (self.in_size[-1], 1)) # parameters self.eps = eps # weights - params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False)) + params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)) if b_init is not None: - params['bias'] = init.param(b_init, self.out_size, allow_none=False) + params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False) # gain if ws_gain: s = params['weight'].shape @@ -531,9 +531,13 @@ def update(self, pre_val): val = pre_val[..., :self.out_size[-1]] post_val = post_val - val else: - size = list(self.out_size) - size[-1] = self.out_size[-1] - self.in_size[-1] - val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)]) + # out_size > in_size: pad the per-unit "self" value with zeros for + # the extra output units, matching pre_val's batch rank, and + # concatenate along the feature (last) axis. + pad_shape = pre_val.shape[:-1] + (self.out_size[-1] - self.in_size[-1],) + val = u.math.concatenate( + [pre_val, u.math.zeros(pad_shape, dtype=pre_val.dtype)], axis=-1 + ) post_val = post_val - val post_val = w_val * post_val diff --git a/brainstate/nn/_linear_test.py b/brainstate/nn/_linear_test.py index 9869f8b1..eb1b786f 100644 --- a/brainstate/nn/_linear_test.py +++ b/brainstate/nn/_linear_test.py @@ -471,5 +471,28 @@ def test_sparse_linear_invalid_input(self): brainstate.nn.SparseLinear(jnp.ones((5, 5))) +class TestLinearAuditRegressions(parameterized.TestCase): + """Regression tests for bugs found in the nn-module audit.""" + + def test_scaled_ws_linear_multidim_sizes(self): + """L1: ScaledWSLinear must build a 2D (in[-1], out[-1]) weight for multi-dim sizes.""" + layer = brainstate.nn.ScaledWSLinear((4, 10), (4, 5)) + self.assertEqual(layer.weight.value['weight'].shape, (10, 5)) + x = jnp.ones((4, 10)) + y = layer(x) + self.assertEqual(y.shape, (4, 5)) + + def test_all_to_all_scalar_weight_out_greater_than_in_no_self(self): + """L2: AllToAll scalar weight + include_self=False must work when out_size > in_size.""" + layer = brainstate.nn.AllToAll((3,), (5,), include_self=False) + layer.weight.value = {'weight': jnp.asarray(2.0)} + # unbatched + out = layer(jnp.ones((3,))) + self.assertEqual(out.shape, (5,)) + # batched + out_b = layer(jnp.ones((2, 3))) + self.assertEqual(out_b.shape, (2, 5)) + + if __name__ == '__main__': unittest.main() diff --git a/brainstate/nn/_metrics.py b/brainstate/nn/_metrics.py index 4c5b1622..80ad2772 100644 --- a/brainstate/nn/_metrics.py +++ b/brainstate/nn/_metrics.py @@ -351,7 +351,7 @@ def reset(self) -> None: This resets count, mean, and the sum of squared deviations (m2). """ - self.count.value = jnp.array(0, dtype=jnp.uint32) + self.count.value = jnp.array(0, dtype=jnp.int32) self.mean.value = jnp.array(0, dtype=jnp.float32) self.m2.value = jnp.array(0, dtype=jnp.float32) @@ -552,6 +552,10 @@ class PrecisionMetric(Metric): __module__ = "brainstate.nn" def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'): + if average not in ('micro', 'macro', 'weighted'): + raise ValueError( + f"`average` must be one of 'micro', 'macro', 'weighted'. Got {average!r}." + ) self.num_classes = num_classes self.average = average if num_classes is None: @@ -560,6 +564,9 @@ def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro') else: self.true_positives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32)) self.false_positives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32)) + # Per-class support (number of true labels of each class), used by the + # 'weighted' average. + self.support = MetricState(jnp.zeros(num_classes, dtype=jnp.int32)) def reset(self) -> None: """Reset the metric state to zero.""" @@ -569,6 +576,7 @@ def reset(self) -> None: else: self.true_positives.value = jnp.zeros(self.num_classes, dtype=jnp.int32) self.false_positives.value = jnp.zeros(self.num_classes, dtype=jnp.int32) + self.support.value = jnp.zeros(self.num_classes, dtype=jnp.int32) def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None: """ @@ -596,6 +604,9 @@ def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None: self.false_positives.value = self.false_positives.value.at[c].add( jnp.sum((predictions == c) & (labels != c)) ) + self.support.value = self.support.value.at[c].add( + jnp.sum(labels == c) + ) def compute(self) -> jax.Array: """ @@ -625,6 +636,14 @@ def compute(self) -> jax.Array: total_tp / (total_tp + total_fp), jnp.float32(0.0) ) + elif self.num_classes is not None and self.average == 'weighted': + support = self.support.value + total = jnp.sum(support) + return jnp.where( + total > 0, + jnp.sum(precision * support) / total, + jnp.float32(0.0) + ) return precision @@ -675,6 +694,10 @@ class RecallMetric(Metric): __module__ = "brainstate.nn" def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'): + if average not in ('micro', 'macro', 'weighted'): + raise ValueError( + f"`average` must be one of 'micro', 'macro', 'weighted'. Got {average!r}." + ) self.num_classes = num_classes self.average = average if num_classes is None: @@ -748,6 +771,14 @@ def compute(self) -> jax.Array: total_tp / (total_tp + total_fn), jnp.float32(0.0) ) + elif self.num_classes is not None and self.average == 'weighted': + # Per-class support equals TP + FN, i.e. the recall denominator. + total = jnp.sum(denominator) + return jnp.where( + total > 0, + jnp.sum(recall * denominator) / total, + jnp.float32(0.0) + ) return recall @@ -867,8 +898,8 @@ class ConfusionMatrix(Metric): >>> metric = brainstate.nn.ConfusionMatrix(num_classes=3) >>> metric.update(predictions=predictions, labels=labels) >>> metric.compute() - Array([[1, 0, 1], - [0, 2, 0], + Array([[1, 0, 0], + [0, 2, 1], [1, 0, 0]], dtype=int32) Notes diff --git a/brainstate/nn/_metrics_test.py b/brainstate/nn/_metrics_test.py index bfabcb47..0da77b51 100644 --- a/brainstate/nn/_metrics_test.py +++ b/brainstate/nn/_metrics_test.py @@ -607,5 +607,43 @@ def test_precision_recall_no_positives(self): self.assertEqual(float(recall), 0.0) +class MetricsAuditRegressionTest(parameterized.TestCase): + """Regression tests for bugs found in the nn-module audit.""" + + def test_weighted_average_returns_scalar(self): + """E1: 'weighted' averaging must produce a support-weighted scalar, not an array.""" + preds = jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 2]) + labels = jnp.array([0, 1, 1, 1, 2, 2, 2, 2, 0]) + for cls in (bst.nn.PrecisionMetric, bst.nn.RecallMetric, bst.nn.F1ScoreMetric): + m = cls(num_classes=3, average='weighted') + m.update(predictions=preds, labels=labels) + out = m.compute() + self.assertEqual(jnp.asarray(out).shape, ()) # scalar, not per-class array + self.assertTrue(0.0 <= float(out) <= 1.0) + + def test_weighted_recall_matches_manual_support_weighting(self): + """E1: weighted recall == sum(recall_c * support_c) / sum(support_c).""" + preds = jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 2]) + labels = jnp.array([0, 1, 1, 1, 2, 2, 2, 2, 0]) + m = bst.nn.RecallMetric(num_classes=3, average='weighted') + m.update(predictions=preds, labels=labels) + # support: class0=2, class1=3, class2=4 ; recall0=1/2, recall1=2/3, recall2=3/4 + expected = (0.5 * 2 + (2 / 3) * 3 + 0.75 * 4) / 9 + self.assertAlmostEqual(float(m.compute()), expected, places=5) + + @parameterized.parameters(bst.nn.PrecisionMetric, bst.nn.RecallMetric, bst.nn.F1ScoreMetric) + def test_invalid_average_raises(self, cls): + """E1: an unknown `average` must raise instead of silently misbehaving.""" + with self.assertRaises(ValueError): + cls(num_classes=3, average='Macro') + + def test_welford_reset_keeps_int32_count(self): + """E4: WelfordMetric.reset must keep count dtype int32 (matching __init__).""" + m = bst.nn.WelfordMetric() + m.update(values=jnp.array([1.0, 2.0, 3.0])) + m.reset() + self.assertEqual(m.count.value.dtype, jnp.int32) + + if __name__ == '__main__': absltest.main() diff --git a/brainstate/nn/_module.py b/brainstate/nn/_module.py index 819d4afc..ecb42fee 100644 --- a/brainstate/nn/_module.py +++ b/brainstate/nn/_module.py @@ -48,6 +48,38 @@ max_int = np.iinfo(np.int32).max +def _format_size_arg(size, attr_name: str) -> tuple: + """Normalize an ``in_size``/``out_size`` argument to a tuple of ints. + + Accepts a Python ``int``, any 0-dimensional integer numpy value (both + ``np.generic`` scalars such as ``np.int64(5)`` and 0-d ``np.ndarray`` such as + ``np.array(5)``), or an existing tuple/list. This unifies the two size + setters, which previously handled only one of the numpy scalar forms each and + additionally called ``np.issubdtype`` on an array instance (a ``TypeError``). + + Parameters + ---------- + size : int, numpy scalar/0-d array, tuple, or list + The raw size specification. + attr_name : str + Name of the attribute (for error messages). + + Returns + ------- + tuple + The size as a tuple of ints. + """ + if isinstance(size, int): + return (size,) + # Both numpy scalars (np.generic) and 0-d ndarrays report ndim == 0; compare + # on the dtype (never the instance) so np.issubdtype receives a valid input. + if isinstance(size, (np.generic, np.ndarray)) and np.ndim(size) == 0: + if np.issubdtype(np.asarray(size).dtype, np.integer): + return (int(size),) + assert isinstance(size, (tuple, list)), f"Invalid type of {attr_name}: {type(size)}" + return tuple(size) + + class Module(Node, ParamDesc): """ Base class for neural network modules in BrainState. @@ -130,13 +162,7 @@ def in_size(self) -> Size: @in_size.setter def in_size(self, in_size: Sequence[int] | int): - if isinstance(in_size, int): - in_size = (in_size,) - elif isinstance(in_size, np.generic): - if np.issubdtype(in_size, np.integer) and in_size.ndim == 0: - in_size = (int(in_size),) - assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {in_size} {type(in_size)}" - self._in_size = tuple(in_size) + self._in_size = _format_size_arg(in_size, 'in_size') @property def out_size(self) -> Size: @@ -144,13 +170,7 @@ def out_size(self) -> Size: @out_size.setter def out_size(self, out_size: Sequence[int] | int): - if isinstance(out_size, int): - out_size = (out_size,) - elif isinstance(out_size, np.ndarray): - if np.issubdtype(out_size, np.integer) and out_size.ndim == 0: - out_size = (int(out_size),) - assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}" - self._out_size = tuple(out_size) + self._out_size = _format_size_arg(out_size, 'out_size') @not_implemented def update(self, *args, **kwargs): @@ -777,14 +797,29 @@ def update(self, x): def __getitem__(self, key: Union[int, slice]): if isinstance(key, slice): - return Sequential(*self.layers[key]) + return self._from_layers(self.layers[key]) elif isinstance(key, int): return self.layers[key] elif isinstance(key, (tuple, list)): - return Sequential(*[self.layers[k] for k in key]) + return self._from_layers([self.layers[k] for k in key]) else: raise KeyError(f'Unknown type of key: {type(key)}') + @classmethod + def _from_layers(cls, layers: Sequence[Module]) -> 'Sequential': + """Build a ``Sequential`` from an already-formatted layer list. + + Handles the empty case (e.g. an out-of-range or degenerate slice), which + ``cls(*layers)`` cannot because ``__init__`` requires at least one layer. + """ + layers = list(layers) + if len(layers) == 0: + seq = cls.__new__(cls) + Module.__init__(seq) + seq.layers = [] + return seq + return cls(*layers) + def append(self, layer: Callable): """ Append a layer to the sequential model. diff --git a/brainstate/nn/_module_test.py b/brainstate/nn/_module_test.py index 68598261..7e83fd6c 100644 --- a/brainstate/nn/_module_test.py +++ b/brainstate/nn/_module_test.py @@ -1150,20 +1150,63 @@ def test_insert_module_first_ok(self): self.assertEqual(seq.in_size, (3,)) self.assertEqual(seq.out_size, (4,)) + def test_empty_slice_returns_empty_sequential(self): + """M6: an out-of-range slice yields an empty Sequential, not a crash.""" + seq = brainstate.nn.Sequential(brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)) + empty = seq[10:20] + self.assertIsInstance(empty, Sequential) + self.assertEqual(len(empty.layers), 0) + + def test_degenerate_slice_returns_empty_sequential(self): + """M6: a degenerate ``[i:i]`` slice yields an empty Sequential.""" + seq = brainstate.nn.Sequential(brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)) + empty = seq[1:1] + self.assertIsInstance(empty, Sequential) + self.assertEqual(len(empty.layers), 0) + + def test_empty_tuple_index_returns_empty_sequential(self): + """M6: indexing with an empty tuple yields an empty Sequential.""" + seq = brainstate.nn.Sequential(brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)) + empty = seq[()] + self.assertIsInstance(empty, Sequential) + self.assertEqual(len(empty.layers), 0) + + def test_nonempty_slice_still_works(self): + """A normal sub-slice still produces a populated Sequential.""" + seq = brainstate.nn.Sequential( + brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6) + ) + sub = seq[1:] + self.assertIsInstance(sub, Sequential) + self.assertEqual(len(sub.layers), 2) + class TestModuleSizeSetterEdges(unittest.TestCase): """Cover numpy-typed and invalid inputs to the size setters.""" - def test_out_size_zero_d_integer_ndarray_bug(self): - """A 0-d integer ``np.ndarray`` ``out_size`` hits a source bug (documents BUG). + def test_out_size_zero_d_integer_ndarray(self): + """M7: a 0-d integer ``np.ndarray`` ``out_size`` is normalized to a tuple.""" + mod = Module() + mod.out_size = np.array(6) + self.assertEqual(mod.out_size, (6,)) - The setter calls ``np.issubdtype(out_size, np.integer)`` passing the array - itself rather than its ``.dtype``; numpy raises ``TypeError`` ("Cannot - construct a dtype from an array"), so a 0-d integer array cannot be used. - """ + def test_out_size_numpy_integer_scalar(self): + """M7: a numpy integer scalar (``np.generic``) ``out_size`` is accepted.""" mod = Module() - with self.assertRaises(TypeError): - mod.out_size = np.array(6) + mod.out_size = np.int64(5) + self.assertEqual(mod.out_size, (5,)) + + def test_in_size_zero_d_integer_ndarray(self): + """M7: a 0-d integer ``np.ndarray`` ``in_size`` is normalized to a tuple.""" + mod = Module() + mod.in_size = np.array(4) + self.assertEqual(mod.in_size, (4,)) + + def test_in_size_numpy_integer_scalar(self): + """A numpy integer scalar (``np.generic``) ``in_size`` is accepted.""" + mod = Module() + mod.in_size = np.int32(3) + self.assertEqual(mod.in_size, (3,)) def test_in_size_non_integer_generic_raises(self): """A non-integer ``np.generic`` ``in_size`` fails the type assertion.""" @@ -1171,6 +1214,12 @@ def test_in_size_non_integer_generic_raises(self): with self.assertRaises(AssertionError): mod.in_size = np.float64(3.0) + def test_out_size_non_integer_zero_d_ndarray_raises(self): + """A non-integer 0-d ``np.ndarray`` ``out_size`` fails the type assertion.""" + mod = Module() + with self.assertRaises(AssertionError): + mod.out_size = np.array(3.0) + class _UnsizedModule(Module): """Module without declared in/out sizes, used for Sequential size tests.""" diff --git a/brainstate/nn/_param.py b/brainstate/nn/_param.py index 096238cd..749c47a3 100644 --- a/brainstate/nn/_param.py +++ b/brainstate/nn/_param.py @@ -231,8 +231,13 @@ def value(self) -> ArrayLike: """ Get current parameter value after applying transform. - Returns cached value when valid. Otherwise, computes ``t.forward(val)``, - caches it, and returns the result. + Returns the cached value when a valid cache exists. The cache is *opt-in*: + it is populated only by an explicit call to :meth:`cache` (and invalidated + automatically on writes / :meth:`set_value` / :meth:`clear_cache`). When no + valid cache exists, this method computes ``t.forward(val)`` fresh on every + call **without** populating the cache. This avoids caching a traced value + across ``jit`` boundaries; call :meth:`cache` explicitly to memoize when the + underlying value is stable. Returns ------- @@ -464,7 +469,12 @@ class Const(Param): A module has non-trainable constant parameter. A convenience class that creates a fixed (non-trainable) parameter. - Equivalent to ``ParamM(value, fit=False)``. + Equivalent to ``Param(value, fit=False)``. + + "Non-trainable" means the value is **not** wrapped in a trainable ``ParamState`` + and is therefore excluded from gradient-based optimization. It does **not** mean + the value is immutable: it can still be updated explicitly via :meth:`set_value` + or :meth:`clip`. Parameters ---------- diff --git a/brainstate/nn/_poolings.py b/brainstate/nn/_poolings.py index b4d6c813..d73ab81f 100644 --- a/brainstate/nn/_poolings.py +++ b/brainstate/nn/_poolings.py @@ -313,9 +313,9 @@ def reducer(acc, operand): def _infer_shape(self, x_dim, inputs, element): channel_axis = self.channel_axis - if channel_axis and not 0 <= abs(channel_axis) < x_dim: + if channel_axis is not None and not -x_dim <= channel_axis < x_dim: raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") - if channel_axis and channel_axis < 0: + if channel_axis is not None and channel_axis < 0: channel_axis = x_dim + channel_axis all_dims = list(range(x_dim)) if channel_axis is not None: @@ -370,16 +370,15 @@ class MaxPool1d(_MaxPool): input(N_i, stride \times k + m, C_j) If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides - for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the - sliding window. This `link`_ has a nice visualization of the pooling parameters. + for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters. Shape: - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`. - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where .. math:: - L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} - \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} + - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor Parameters ---------- @@ -457,20 +456,19 @@ class MaxPool2d(_MaxPool): \end{aligned} If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides - for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. - It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters. Shape: - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)` - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where .. math:: - H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]} - \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor + H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} + - \text{kernel\_size[0]}}{\text{stride[0]}} + 1\right\rfloor .. math:: - W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]} - \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor + W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} + - \text{kernel\_size[1]}}{\text{stride[1]}} + 1\right\rfloor Parameters ---------- @@ -550,24 +548,23 @@ class MaxPool3d(_MaxPool): \end{aligned} If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides - for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. - It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters. Shape: - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`. - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where .. math:: - D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times - (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] + - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor .. math:: - H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times - (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] + - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor .. math:: - W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times - (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] + - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor Parameters ---------- @@ -710,23 +707,33 @@ def _compute_output_shape(self, input_shape, output_size=None): return tuple(output_shape) def _unpool_nd(self, x, indices, output_size=None): - """Perform N-dimensional max unpooling.""" + """Perform N-dimensional max unpooling. + + The ``indices`` returned by :class:`MaxPool` are flat positions within the + *natural* (inferred) output shape, i.e. positions in the layout that the + original pre-pool input had. To support an arbitrary ``output_size`` (whose + spatial extent may differ from the natural one) without leaking values + across batch or channel elements, each stored flat index is first unraveled + into a multi-dimensional coordinate against the natural shape and then + re-raveled against the requested output shape. Coordinates that fall outside + the requested output are dropped. + """ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) if x.ndim < x_dim: raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.') - # Determine output shape + # Natural output shape: the layout the stored ``indices`` are flat positions in. + spatial_dims = self._get_spatial_dims(x.shape) + natural_spatial_shape = self._compute_output_shape(spatial_dims, None) + natural_shape = list(x.shape) + spatial_start = self._get_spatial_start_idx(x.ndim) + for i, size in enumerate(natural_spatial_shape): + natural_shape[spatial_start + i] = size + natural_shape = tuple(natural_shape) + + # Determine the requested output shape. if output_size is None: - # Infer output shape from input shape - spatial_dims = self._get_spatial_dims(x.shape) - output_spatial_shape = self._compute_output_shape(spatial_dims, output_size) - output_shape = list(x.shape) - - # Update spatial dimensions in output shape - spatial_start = self._get_spatial_start_idx(x.ndim) - for i, size in enumerate(output_spatial_shape): - output_shape[spatial_start + i] = size - output_shape = tuple(output_shape) + output_shape = natural_shape else: # Use provided output size if isinstance(output_size, (list, tuple)): @@ -738,64 +745,48 @@ def _unpool_nd(self, x, indices, output_size=None): if len(output_size) != self.pool_dim: raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}") output_shape = list(x.shape) - spatial_start = self._get_spatial_start_idx(x.ndim) for i, size in enumerate(output_size): output_shape[spatial_start + i] = size output_shape = tuple(output_shape) else: # Single integer provided, use for all spatial dims output_shape = list(x.shape) - spatial_start = self._get_spatial_start_idx(x.ndim) for i in range(self.pool_dim): output_shape[spatial_start + i] = output_size output_shape = tuple(output_shape) - # Create output array filled with zeros - output = jnp.zeros(output_shape, dtype=x.dtype) - - # # Scatter input values to output using indices - # # Flatten spatial dimensions for easier indexing - # batch_dims = x.ndim - self.pool_dim - (0 if self.channel_axis is None else 1) - # - # # Reshape for processing - # if batch_dims > 0: - # batch_shape = x.shape[:batch_dims] - # if self.channel_axis is not None and self.channel_axis < batch_dims: - # # Channel axis is before spatial dims - # channel_idx = self.channel_axis - # n_channels = x.shape[channel_idx] - # elif self.channel_axis is not None: - # # Channel axis is after spatial dims - # if self.channel_axis < 0: - # channel_idx = x.ndim + self.channel_axis - # else: - # channel_idx = self.channel_axis - # n_channels = x.shape[channel_idx] - # else: - # n_channels = None - # else: - # batch_shape = () - # if self.channel_axis is not None: - # if self.channel_axis < 0: - # channel_idx = x.ndim + self.channel_axis - # else: - # channel_idx = self.channel_axis - # n_channels = x.shape[channel_idx] - # else: - # n_channels = None - - # Use JAX's scatter operation - # Flatten the indices to 1D for scatter + # Flatten values and the stored within-natural-layout indices. flat_indices = indices.ravel() flat_values = x.ravel() - flat_output = output.ravel() - - # Scatter the values - flat_output = flat_output.at[flat_indices].set(flat_values) - - # Reshape back to original shape - output = flat_output.reshape(output_shape) + if output_shape == natural_shape: + # Fast path: natural-size round-trip, indices already match the output + # layout exactly, so scatter directly. + flat_output = jnp.zeros(int(np.prod(output_shape)), dtype=x.dtype) + flat_output = flat_output.at[flat_indices].set(flat_values) + return flat_output.reshape(output_shape) + + # General path: map each natural-layout flat index to its multi-dimensional + # coordinate, then re-flatten against the requested output shape so that + # every value stays within its own batch / channel element. + coords = jnp.unravel_index(flat_indices, natural_shape) + + # Determine which coordinates remain inside the requested output bounds. + in_bounds = jnp.ones(flat_indices.shape, dtype=bool) + for coord, dim in zip(coords, output_shape): + in_bounds = in_bounds & (coord < dim) + + # Clip coordinates so ravel stays valid, then route out-of-bounds entries to + # a throwaway slot appended to the flat output (dropped after the scatter). + clipped_coords = tuple(jnp.minimum(coord, dim - 1) for coord, dim in zip(coords, output_shape)) + out_flat_size = int(np.prod(output_shape)) + new_flat = jnp.ravel_multi_index(clipped_coords, output_shape, mode='clip') + dump_slot = out_flat_size + scatter_indices = jnp.where(in_bounds, new_flat, dump_slot) + + flat_output = jnp.zeros(out_flat_size + 1, dtype=x.dtype) + flat_output = flat_output.at[scatter_indices].set(flat_values) + output = flat_output[:out_flat_size].reshape(output_shape) return output def _get_spatial_dims(self, shape): @@ -1070,6 +1061,14 @@ class AvgPool1d(_AvgPool): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + .. note:: + Padding cells are **excluded** from the average: each output is divided by the + number of *valid* (non-padded) input elements that fall inside its window, not + by the full window size. This corresponds to ``count_include_pad=False`` in + PyTorch. As a result, the ``1 / k`` divisor in the formula above only holds for + windows that contain no padding; windows overlapping the padded border are + divided by their valid-element count instead. + Shape: - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`. - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where @@ -1147,6 +1146,14 @@ class AvgPool2d(_AvgPool): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + .. note:: + Padding cells are **excluded** from the average: each output is divided by the + number of *valid* (non-padded) input elements that fall inside its window, not + by the full window size. This corresponds to ``count_include_pad=False`` in + PyTorch. As a result, the ``1 / k`` divisor in the formula above only holds for + windows that contain no padding; windows overlapping the padded border are + divided by their valid-element count instead. + Shape: - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`. - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where @@ -1235,6 +1242,14 @@ class AvgPool3d(_AvgPool): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides for :attr:`padding` number of points. + .. note:: + Padding cells are **excluded** from the average: each output is divided by the + number of *valid* (non-padded) input elements that fall inside its window, not + by the full window size. This corresponds to ``count_include_pad=False`` in + PyTorch. As a result, the ``1 / (kD \times kH \times kW)`` divisor in the formula + above only holds for windows that contain no padding; windows overlapping the + padded border are divided by their valid-element count instead. + Shape: - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`. - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or @@ -1433,9 +1448,9 @@ def update(self, x): def _infer_shape(self, x_dim, inputs, element): channel_axis = self.channel_axis - if channel_axis and not 0 <= abs(channel_axis) < x_dim: + if channel_axis is not None and not -x_dim <= channel_axis < x_dim: raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") - if channel_axis and channel_axis < 0: + if channel_axis is not None and channel_axis < 0: channel_axis = x_dim + channel_axis all_dims = list(range(x_dim)) if channel_axis is not None: @@ -1450,15 +1465,23 @@ def _infer_shape(self, x_dim, inputs, element): class LPPool1d(_LPPool): r"""Applies a 1D power-average pooling over an input signal composed of several input planes. - On each window, the function computed is: + On each window, the function computed is the (normalized) power-mean: .. math:: - f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}} + f(X) = \left( \frac{1}{N} \sum_{x \in X} |x|^{p} \right)^{1/p} + + where :math:`N` is the number of elements in the window + (:math:`N = \prod_i \text{kernel\_size}[i]`). - At :math:`p = \infty`, one gets max pooling - At :math:`p = 1`, one gets average pooling (with absolute values) - At :math:`p = 2`, one gets root mean square (RMS) pooling + .. note:: + This is a *normalized* power-mean (the sum is divided by the window size + :math:`N`). It therefore differs from PyTorch's ``LPPool``, which computes + the *unnormalized* power-sum :math:`\left( \sum_{x \in X} |x|^{p} \right)^{1/p}`. + Shape: - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`. - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where @@ -1526,15 +1549,23 @@ def __init__( class LPPool2d(_LPPool): r"""Applies a 2D power-average pooling over an input signal composed of several input planes. - On each window, the function computed is: + On each window, the function computed is the (normalized) power-mean: .. math:: - f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}} + f(X) = \left( \frac{1}{N} \sum_{x \in X} |x|^{p} \right)^{1/p} + + where :math:`N` is the number of elements in the window + (:math:`N = \prod_i \text{kernel\_size}[i]`). - At :math:`p = \infty`, one gets max pooling - At :math:`p = 1`, one gets average pooling (with absolute values) - At :math:`p = 2`, one gets root mean square (RMS) pooling + .. note:: + This is a *normalized* power-mean (the sum is divided by the window size + :math:`N`). It therefore differs from PyTorch's ``LPPool``, which computes + the *unnormalized* power-sum :math:`\left( \sum_{x \in X} |x|^{p} \right)^{1/p}`. + Shape: - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)` - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where @@ -1607,15 +1638,23 @@ def __init__( class LPPool3d(_LPPool): r"""Applies a 3D power-average pooling over an input signal composed of several input planes. - On each window, the function computed is: + On each window, the function computed is the (normalized) power-mean: .. math:: - f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}} + f(X) = \left( \frac{1}{N} \sum_{x \in X} |x|^{p} \right)^{1/p} + + where :math:`N` is the number of elements in the window + (:math:`N = \prod_i \text{kernel\_size}[i]`). - At :math:`p = \infty`, one gets max pooling - At :math:`p = 1`, one gets average pooling (with absolute values) - At :math:`p = 2`, one gets root mean square (RMS) pooling + .. note:: + This is a *normalized* power-mean (the sum is divided by the window size + :math:`N`). It therefore differs from PyTorch's ``LPPool``, which computes + the *unnormalized* power-sum :math:`\left( \sum_{x \in X} |x|^{p} \right)^{1/p}`. + Shape: - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`. - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where @@ -1784,8 +1823,8 @@ def update(self, x): # channel axis channel_axis = self.channel_axis - if channel_axis: - if not 0 <= abs(channel_axis) < x.ndim: + if channel_axis is not None: + if not -x.ndim <= channel_axis < x.ndim: raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}") if channel_axis < 0: channel_axis = x.ndim + channel_axis @@ -1801,9 +1840,13 @@ def update(self, x): # pooling for i, di in enumerate(pool_dims[-len(self.target_shape):]): + target = self.target_shape[i] + # A ``None`` target means "do not pool this axis" -- leave it unchanged. + if target is None: + continue poo_axes = [j for j in range(x.ndim) if j != di] op = _generate_vmap(_adaptive_pool1d, poo_axes) - x = op(x, self.target_shape[i], self.operation) + x = op(x, target, self.operation) return x diff --git a/brainstate/nn/_poolings_test.py b/brainstate/nn/_poolings_test.py index 83e55599..c2d0ec18 100644 --- a/brainstate/nn/_poolings_test.py +++ b/brainstate/nn/_poolings_test.py @@ -563,6 +563,61 @@ def test_maxunpool1d_with_output_size(self): self.assertEqual(unpooled.shape, (1, 10, 2)) + def test_maxunpool1d_natural_roundtrip_no_cross_batch(self): + """Natural-size unpool keeps each batch element's maxima in that element.""" + # input (2, 4, 1): batch 0 -> [0,1,2,3], batch 1 -> [4,5,6,7] + arr = jnp.arange(8.0).reshape(2, 4, 1) + + pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True) + pooled, indices = pool(arr) + + unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1) + unpooled = unpool(pooled, indices) + + self.assertEqual(unpooled.shape, (2, 4, 1)) + # Each batch element's maxima land in the SAME batch element. + np.testing.assert_array_equal( + np.asarray(unpooled[..., 0]), + np.array([[0.0, 1.0, 0.0, 3.0], + [0.0, 5.0, 0.0, 7.0]]), + ) + + def test_maxunpool1d_output_size_no_cross_batch_leakage(self): + """A non-natural output_size must not leak maxima across batch elements. + + Regression test for the flat-scatter bug where, for batch N>1, values + landed in the wrong batch element because the per-batch flat stride of the + output differs from that of the input layout when ``output_size`` changes + the spatial extent. + """ + # input (2, 4, 1): batch 0 -> [0,1,2,3], batch 1 -> [4,5,6,7] + arr = jnp.arange(8.0).reshape(2, 4, 1) + + pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True) + pooled, indices = pool(arr) + # pooled maxima: batch 0 -> [1, 3], batch 1 -> [5, 7] + + unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1) + # Request a spatial size (6) that differs from the natural size (4). + unpooled = unpool(pooled, indices, output_size=(2, 6, 1)) + + self.assertEqual(unpooled.shape, (2, 6, 1)) + + b0 = np.asarray(unpooled[0, :, 0]) + b1 = np.asarray(unpooled[1, :, 0]) + + # Batch 0 must contain exactly its own maxima {1, 3} and nothing from batch 1. + self.assertEqual(set(np.unique(b0[b0 != 0]).tolist()), {1.0, 3.0}) + self.assertNotIn(5.0, b0.tolist()) + self.assertNotIn(7.0, b0.tolist()) + # Batch 1 must contain exactly its own maxima {5, 7}. + self.assertEqual(set(np.unique(b1[b1 != 0]).tolist()), {5.0, 7.0}) + # Positions within each batch element are preserved (col index = original). + self.assertEqual(b0[1], 1.0) + self.assertEqual(b0[3], 3.0) + self.assertEqual(b1[1], 5.0) + self.assertEqual(b1[3], 7.0) + class TestMaxUnpool2d(parameterized.TestCase): """Comprehensive tests for MaxUnpool2d.""" @@ -1065,6 +1120,29 @@ def test_invalid_channel_axis_raises(self): with self.assertRaises(ValueError): p(brainstate.random.randn(3, 4)) + def test_most_negative_channel_axis_accepted(self): + """``channel_axis == -ndim`` is valid and matches the positive equivalent. + + Regression test: the previous guard rejected the most-negative valid axis. + """ + x = brainstate.random.randn(4, 10) + out_neg = nn.MaxPool1d(2, channel_axis=-2)(x) # -2 == -ndim for a 2D input + out_pos = nn.MaxPool1d(2, channel_axis=0)(x) + self.assertEqual(out_neg.shape, (4, 5)) + np.testing.assert_array_equal(np.asarray(out_neg), np.asarray(out_pos)) + + def test_too_negative_channel_axis_raises(self): + """A channel axis more negative than ``-ndim`` is still rejected.""" + p = nn.MaxPool1d(2, channel_axis=-3) + with self.assertRaises(ValueError): + p(brainstate.random.randn(3, 4)) + + def test_most_negative_channel_axis_lppool_and_adaptive(self): + """``channel_axis == -ndim`` is accepted by LPPool and adaptive pooling too.""" + x = brainstate.random.randn(4, 10) + self.assertEqual(nn.LPPool1d(2, 2, channel_axis=-2)(x).shape, (4, 5)) + self.assertEqual(nn.AdaptiveAvgPool1d(5, channel_axis=-2)(x).shape, (4, 5)) + class TestMaxPoolTransforms(unittest.TestCase): """Gradient and jit consistency checks for the max-pooling variants.""" @@ -1387,20 +1465,38 @@ def test_negative_channel_axis_forward(self): out = p(brainstate.random.randn(2, 32, 4)) self.assertEqual(out.shape, (2, 5, 4)) - @pytest.mark.skip( - reason="BUG: AdaptiveAvgPool2d/3d (and Max variants) document a `None` " - "target dim ('Use None for dimensions that should not be pooled') " - "but `_adaptive_pool1d` computes `size % target_size`, raising " - "`TypeError: unsupported operand type(s) for %: 'int' and 'NoneType'` " - "when target_size is None. Repro: " - "nn.AdaptiveAvgPool2d((None, 7), channel_axis=-1)(randn(1, 10, 9, 8))." - ) def test_adaptive_avg_pool_with_none_target_dim(self): """A ``None`` entry in the target size leaves that dimension unchanged.""" p = nn.AdaptiveAvgPool2d((None, 7), channel_axis=-1) out = p(brainstate.random.randn(1, 10, 9, 8)) self.assertEqual(out.shape, (1, 10, 7, 8)) + def test_adaptive_max_pool2d_with_none_target_dim(self): + """``None`` target dims are also supported by the max variant (2d).""" + p = nn.AdaptiveMaxPool2d((None, 7), channel_axis=-1) + out = p(brainstate.random.randn(1, 10, 9, 8)) + self.assertEqual(out.shape, (1, 10, 7, 8)) + + def test_adaptive_avg_pool3d_with_none_target_dims(self): + """Multiple ``None`` target dims leave those axes unchanged (3d).""" + p = nn.AdaptiveAvgPool3d((7, None, None), channel_axis=-1) + out = p(brainstate.random.randn(1, 10, 9, 8, 64)) + self.assertEqual(out.shape, (1, 7, 9, 8, 64)) + + def test_adaptive_avg_pool_none_target_preserves_values(self): + """An unpooled (``None``) axis is left numerically unchanged.""" + # Pool only the last spatial axis; the first spatial axis (None) untouched. + p = nn.AdaptiveAvgPool2d((None, 2), channel_axis=-1) + x = brainstate.random.randn(1, 3, 4, 2) + out = p(x) + self.assertEqual(out.shape, (1, 3, 2, 2)) + # Manually average the width axis (size 4 -> 2) and compare. + expected = jnp.stack( + [jnp.mean(x[:, :, 0:2, :], axis=2), jnp.mean(x[:, :, 2:4, :], axis=2)], + axis=2, + ) + np.testing.assert_allclose(np.asarray(out), np.asarray(expected), rtol=1e-5, atol=1e-6) + class TestAdaptivePoolTransforms(unittest.TestCase): """Gradient and jit consistency checks for the adaptive-pooling variants.""" diff --git a/brainstate/nn/_regularization.py b/brainstate/nn/_regularization.py index f22b2aee..d756bfd7 100644 --- a/brainstate/nn/_regularization.py +++ b/brainstate/nn/_regularization.py @@ -21,6 +21,7 @@ and prevent overfitting. """ +import math from abc import ABC, abstractmethod import brainstate @@ -869,11 +870,21 @@ def loss(self, value: Data) -> Data: remainder = n_elements % self.group_size if remainder != 0: padding = self.group_size - remainder - flat = u.math.concatenate([flat, u.math.zeros(padding)]) + # Build the zero padding with the same unit and dtype as ``flat`` so + # that brainunit Quantity inputs do not raise a UnitMismatchError + # when concatenated (a dimensionless zero array cannot be combined + # with a quantity that carries a physical unit). + pad = u.math.zeros(padding, dtype=u.get_mantissa(flat).dtype) * u.get_unit(flat) + flat = u.math.concatenate([flat, pad]) # Reshape into groups and compute L2 norm of each group groups = u.math.reshape(flat, (-1, self.group_size)) - group_norms = u.math.sqrt(u.math.sum(groups ** 2, axis=1) + 1e-8) + sq_sum = u.math.sum(groups ** 2, axis=1) + # Numerical-stability floor for the sqrt. It must carry the same unit as + # ``sq_sum`` (the squared unit of the input); a dimensionless epsilon + # would raise a UnitMismatchError for brainunit Quantity inputs. + eps = 1e-8 * u.get_unit(sq_sum) + group_norms = u.math.sqrt(sq_sum + eps) return self.weight * u.math.sum(group_norms) @@ -1345,10 +1356,14 @@ def sample_init(self, shape: Size) -> Data: shape_tuple = get_size(shape) if len(shape_tuple) == 1: n = shape_tuple[0] - # Generate random matrix and orthogonalize - random_matrix = brainstate.random.randn(int(u.math.sqrt(float(n))), int(u.math.sqrt(float(n)))) + # Generate a square matrix large enough to yield at least n elements, + # orthogonalize, flatten, then slice to exactly n. Using ceil(sqrt(n)) + # (rather than floor) avoids the reshape ValueError for non-perfect + # squares such as n=5 (floor -> 2x2 -> only 4 elements). + side = int(math.ceil(math.sqrt(float(n)))) + random_matrix = brainstate.random.randn(side, side) q, _ = u.math.linalg.qr(random_matrix) - return u.math.reshape(q, (n,))[:n] + return u.math.reshape(q, (-1,))[:n] elif len(shape_tuple) == 2: m, n = shape_tuple random_matrix = brainstate.random.randn(m, n) @@ -1607,7 +1622,10 @@ def loss(self, value: Data) -> Data: scale = u.math.relu(get_value(self.scale)) + 1e-8 z = value / scale - return self.weight * u.math.sum(u.math.log(1.0 + z ** 2 / df)) + # Student-t negative-log-likelihood data term: 0.5*(df+1)*log(1 + z**2/df). + # The (df+1)/2 factor is df-dependent; without it the penalty wrongly + # vanishes as df -> infinity instead of approaching the Gaussian 0.5*z**2. + return self.weight * u.math.sum(0.5 * (df + 1.0) * u.math.log(1.0 + z ** 2 / df)) def sample_init(self, shape: Size) -> Data: """ diff --git a/brainstate/nn/_regularization_test.py b/brainstate/nn/_regularization_test.py index 0e9c2036..8f96db14 100644 --- a/brainstate/nn/_regularization_test.py +++ b/brainstate/nn/_regularization_test.py @@ -18,6 +18,7 @@ import unittest import brainstate +import brainunit as u import jax.numpy as jnp import numpy as np @@ -343,6 +344,14 @@ def test_group_size_one_like_l1(self): # sqrt(1) + sqrt(4) + sqrt(9) = 1 + 2 + 3 = 6 (approximately, with epsilon) np.testing.assert_allclose(loss, 6.0, rtol=1e-3) + def test_quantity_input_with_padding(self): + """Quantity input whose size is not a multiple of group_size (bug R3).""" + # param size 5 is not a multiple of group_size 2 -> padding needed + reg = GroupLassoReg(weight=1.0, group_size=2) + value = jnp.array([1.0, 0.5, -0.5, 0.2, 0.3]) * u.mV + loss = reg.loss(value) + self.assertTrue(np.isfinite(float(u.get_mantissa(loss)))) + def test_sample_init_shape(self): """Test sample_init returns correct shape.""" reg = GroupLassoReg(weight=1.0, group_size=2) @@ -514,6 +523,13 @@ def test_sample_init_2d_shape(self): sample = reg.sample_init((4, 3)) self.assertEqual(sample.shape, (4, 3)) + def test_sample_init_1d_non_perfect_square(self): + """1D non-perfect-square shapes must not raise (bug R2).""" + reg = OrthogonalReg(weight=1.0) + for n in (5, 7): + sample = reg.sample_init((n,)) + self.assertEqual(sample.shape, (n,)) + def test_reset_value(self): """Test reset_value returns zero.""" reg = OrthogonalReg(weight=1.0) @@ -622,6 +638,31 @@ def test_basic_loss(self): # At x=0, loss = log(1 + 0) = 0 np.testing.assert_allclose(loss, 0.0, atol=1e-5) + def test_nll_factor_df3(self): + """Per-element penalty must include the (df+1)/2 factor (bug R1).""" + reg = StudentTReg(weight=1.0, df=3.0, scale=1.0) + value = jnp.array([2.0]) + loss = reg.loss(value) + # Correct Student-t NLL data term: 0.5*(df+1)*log(1 + z**2/df) + # df=3, z=2 -> 0.5*4*log(1 + 4/3) = 2*log(7/3) ~ 1.6946 + expected = 0.5 * (3.0 + 1.0) * np.log(1.0 + 4.0 / 3.0) + np.testing.assert_allclose(float(loss), expected, rtol=1e-5) + # And it must NOT equal the unfactored value ~0.847 + unfactored = np.log(1.0 + 4.0 / 3.0) + self.assertGreater(float(loss), unfactored * 1.5) + + def test_gaussian_limit_large_df(self): + """As df -> infinity the penalty approaches the Gaussian 0.5*z**2 (bug R1).""" + # Use float64 so the assertion exercises the math, not float32 rounding + # of log(1 + 4e-6). Under float32 the value is ~2.027 (a precision + # artifact); the limit itself is correct. + with brainstate.environ.context(precision=64): + reg = StudentTReg(weight=1.0, df=1e6, scale=1.0) + value = jnp.asarray([2.0], dtype=brainstate.environ.dftype()) + loss = reg.loss(value) + gaussian_limit = 0.5 * 2.0 ** 2 # 2.0 + np.testing.assert_allclose(float(loss), gaussian_limit, rtol=1e-3) + def test_loss_increases_with_distance(self): """Test that loss increases with distance from zero.""" reg = StudentTReg(weight=1.0, df=3.0, scale=1.0) diff --git a/brainstate/nn/_transform.py b/brainstate/nn/_transform.py index 23cc997e..8ecdb7e6 100644 --- a/brainstate/nn/_transform.py +++ b/brainstate/nn/_transform.py @@ -397,9 +397,19 @@ def log_abs_det_jacobian(self, x: ArrayLike, y: ArrayLike) -> Array: For sigmoid: d/dx[lower + width * sigmoid(x)] = width * sigmoid(x) * (1 - sigmoid(x)) log|det J| = sum(log(width) + log(sigmoid(x)) + log(1 - sigmoid(x))) + + Notes + ----- + ``width`` may carry physical units (when ``lower``/``upper`` are + :class:`~brainunit.Quantity`). The log-determinant is a dimensionless + log-density correction, so the unit is stripped from ``width`` via + :func:`brainunit.get_mantissa` before taking the logarithm. The numerically + stable identity ``log(sigmoid(x)) + log(1 - sigmoid(x)) = log_sigmoid(x) + + log_sigmoid(-x)`` avoids ``log(0)`` for large ``|x|``. """ - s = jax.nn.sigmoid(x) - return jnp.sum(jnp.log(self.width) + jnp.log(s) + jnp.log(1 - s), axis=-1) + log_width = jnp.log(u.get_mantissa(self.width)) + log_s = jax.nn.log_sigmoid(x) + jax.nn.log_sigmoid(-x) + return jnp.sum(log_width + log_s, axis=-1) class SoftplusT(Transform): @@ -490,10 +500,11 @@ def forward(self, x: ArrayLike) -> Array: Notes ----- - Uses log1p for numerical stability: log1p(exp(x)) = log(1 + exp(x)). - For large x, this avoids overflow in the exponential. + Uses ``jax.nn.softplus`` for numerical stability: softplus(x) = log(1 + exp(x)), + which is exact for all ``x`` (the previous ``log1p(exp(x))`` form saturated for + ``x`` beyond ~20, breaking the forward map and round-trip). """ - return jnp.log1p(save_exp(x)) * self.unit + self.lower + return jax.nn.softplus(x) * self.unit + self.lower def inverse(self, y: ArrayLike) -> Array: """ @@ -512,9 +523,14 @@ def inverse(self, y: ArrayLike) -> Array: Notes ----- Input must be strictly greater than lower bound to avoid numerical issues. - Uses numerically stable exponential for large (y - lower) values. + Uses the numerically stable softplus inverse ``z + log(-expm1(-z))`` which is + accurate across the whole range (the previous ``log(exp(z) - 1)`` form clipped + ``z`` at ~20 and so failed to invert large constrained values). ``z`` is + dimensionless after dividing out ``self.unit``, so its mantissa is taken to + keep the bare-``jnp`` log/expm1 calls valid. """ - return u.math.log(save_exp((y - self.lower) / self.unit) - 1.0) + z = u.get_mantissa((y - self.lower) / self.unit) + return z + jnp.log(-jnp.expm1(-z)) def log_abs_det_jacobian(self, x: ArrayLike, y: ArrayLike) -> Array: r""" @@ -523,7 +539,7 @@ def log_abs_det_jacobian(self, x: ArrayLike, y: ArrayLike) -> Array: For softplus: d/dx[log(1 + exp(x))] = sigmoid(x) log|det J| = sum(log(sigmoid(x))) = sum(x - softplus(x)) """ - return jnp.sum(x - jnp.log1p(save_exp(x)), axis=-1) + return jnp.sum(jax.nn.log_sigmoid(x), axis=-1) class NegSoftplusT(SoftplusT): @@ -615,9 +631,9 @@ def forward(self, x: ArrayLike) -> Array: Notes ----- - Implemented as: upper - softplus(-x). + Implemented as: upper - softplus(-x), using the stable ``jax.nn.softplus``. """ - return self.lower - jnp.log1p(save_exp(-x)) * self.unit + return self.lower - jax.nn.softplus(-x) * self.unit def inverse(self, y: ArrayLike) -> Array: """ @@ -636,9 +652,11 @@ def inverse(self, y: ArrayLike) -> Array: Notes ----- Inverts: y = upper - softplus(-x) => x = -softplus^{-1}(upper - y). + ``s`` is dimensionless after dividing out ``self.unit``, so its mantissa is + taken to keep the bare-``jnp`` log/expm1 calls valid. """ - s = (self.lower - y) / self.unit - return -u.math.log(save_exp(s) - 1.0) + s = u.get_mantissa((self.lower - y) / self.unit) + return -(s + jnp.log(-jnp.expm1(-s))) class LogT(Transform): @@ -654,6 +672,19 @@ class LogT(Transform): ---------- lower : array_like Lower bound of the target interval. + + Notes + ----- + .. important:: + + ``forward`` uses :func:`save_exp`, which clips the exponent at + ``max_value=20`` for numerical stability. Inputs with ``x > 20`` therefore + saturate at ``lower + exp(20) * unit`` (``exp(20) ≈ 4.85e8``) and the + transform is **not invertible** in that regime — ``inverse(forward(x))`` + will not round-trip for ``x > 20``. The analytic ``log_abs_det_jacobian`` + (``sum(x)``) is likewise only valid where the exponential is unclipped. + Keep the unconstrained representation within roughly ``[-20, 20]`` to stay + in the bijective region. """ __module__ = 'brainstate.nn' @@ -681,6 +712,16 @@ class ExpT(Transform): Exponential transformation mapping (-inf, +inf) to (lower, +inf). Equivalent to Log; provided for explicit naming. + + Notes + ----- + .. important:: + + ``forward`` uses :func:`save_exp`, which clips the exponent at + ``max_value=20``. Inputs with ``x > 20`` saturate at + ``lower + exp(20) * unit`` and the transform is **not invertible** there; + the analytic ``log_abs_det_jacobian`` (``sum(x)``) is only valid in the + unclipped region. Keep inputs within roughly ``[-20, 20]``. """ __module__ = 'brainstate.nn' @@ -930,7 +971,10 @@ def __init__(self, scale: ArrayLike, shift: ArrayLike): If scale is zero or numerically close to zero, making the transformation non-invertible. """ - if jnp.allclose(scale, 0): + # Compare on the mantissa so a unit-carrying scale (e.g. unit conversions) + # does not break the zero check; invertibility only requires a non-zero + # magnitude. + if jnp.allclose(u.get_mantissa(scale), 0): raise ValueError("a cannot be zero, must be invertible") self.a = scale self.b = shift @@ -971,9 +1015,27 @@ def inverse(self, x: ArrayLike) -> Array: return (x - self.b) / self.a def log_abs_det_jacobian(self, x: ArrayLike, y: ArrayLike) -> Array: - """For affine: d/dx[ax + b] = a, so log|det J| = n * log|a|.""" - n = jnp.shape(x)[-1] if jnp.ndim(x) > 0 else 1 - return n * jnp.log(jnp.abs(self.a)) + r""" + Compute log absolute determinant of the Jacobian. + + For affine ``y = a * x + b`` the Jacobian is diagonal with entries ``a``, + so ``log|det J| = sum_i log|a_i|`` over the event (last) axis. + + Notes + ----- + - The scale ``a`` may be a scalar or a per-dimension array; it is broadcast + against ``x`` and summed over the last axis, so a batched input of shape + ``(B, n)`` yields a ``(B,)`` result rather than a single scalar. + - ``a`` may carry physical units (unit conversions). The log-determinant is + a dimensionless log-density correction, so the unit is stripped via + :func:`brainunit.get_mantissa` before taking the logarithm. + """ + log_a = jnp.log(jnp.abs(u.get_mantissa(self.a))) + if jnp.ndim(x) == 0: + return jnp.sum(log_a) + # Broadcast the (possibly per-dimension) scale across x and contract the + # event axis, preserving any leading batch dimensions. + return jnp.sum(jnp.broadcast_to(log_a, jnp.shape(x)), axis=-1) class ChainT(Transform): @@ -1285,6 +1347,16 @@ class PositiveT(Transform): .. math:: \text{inverse}(y) = \log(y) + Notes + ----- + .. important:: + + ``forward`` uses :func:`save_exp`, which clips the exponent at + ``max_value=20``. Inputs with ``x > 20`` saturate at ``exp(20) ≈ 4.85e8`` + and the transform is **not invertible** there; the analytic + ``log_abs_det_jacobian`` (``sum(x)``) is only valid in the unclipped + region. Keep inputs within roughly ``[-20, 20]``. + Examples -------- >>> transform = PositiveT() @@ -1349,11 +1421,12 @@ def __repr__(self) -> str: def forward(self, x: ArrayLike) -> Array: """Transform unbounded input to negative values.""" - return -jnp.log1p(save_exp(-x)) + return -jax.nn.softplus(-x) def inverse(self, y: ArrayLike) -> Array: """Transform negative input back to unbounded domain.""" - return -u.math.log(save_exp(-y) - 1.0) + s = -u.get_mantissa(y) + return -(s + jnp.log(-jnp.expm1(-s))) class ScaledSigmoidT(Transform): @@ -1539,14 +1612,15 @@ def __repr__(self) -> str: def forward(self, x: ArrayLike) -> Array: """Transform unconstrained input to ordered vectors.""" first = x[..., :1] - rest = jnp.log1p(save_exp(x[..., 1:])) + rest = jax.nn.softplus(x[..., 1:]) return jnp.concatenate([first, first + jnp.cumsum(rest, axis=-1)], axis=-1) def inverse(self, y: ArrayLike) -> Array: """Transform ordered vectors back to unconstrained domain.""" first = y[..., :1] diffs = y[..., 1:] - y[..., :-1] - rest = u.math.log(u.math.exp(diffs) - 1) + # Stable softplus inverse of the positive gaps. + rest = diffs + jnp.log(-jnp.expm1(-diffs)) return jnp.concatenate([first, rest], axis=-1) diff --git a/brainstate/nn/_transform_test.py b/brainstate/nn/_transform_test.py index 57433c4d..d6d85d9f 100644 --- a/brainstate/nn/_transform_test.py +++ b/brainstate/nn/_transform_test.py @@ -448,5 +448,76 @@ def test_clip_repr(self): self.assertIn("upper=1.0", repr(t)) +class TestTransformAuditRegressions(unittest.TestCase): + """Regression tests for audit findings T2-T5 (transform numerics/units).""" + + def test_softplus_roundtrip_large_constrained_value(self): + # T2/T3: the old log1p(save_exp(x)) forward clipped at exp(20), so large + # unconstrained inputs saturated and inverse(forward(x)) failed to round-trip. + t = SoftplusT(0.0) + x = jnp.array([25.0, 30.0, 50.0]) # all beyond the old save_exp clip (20) + y = t.forward(x) + # forward should be ~identity for large x (softplus(x) ≈ x), not clipped. + np.testing.assert_allclose(u.get_mantissa(y), np.asarray(x), rtol=1e-4) + xr = t.inverse(y) + np.testing.assert_allclose(np.asarray(xr), np.asarray(x), rtol=1e-4) + + def test_negsoftplus_roundtrip_large_value(self): + # T2/T3: same saturation bug in the reflected transform. + t = NegSoftplusT(0.0) + x = jnp.array([25.0, 40.0]) + y = t.forward(x) + xr = t.inverse(y) + np.testing.assert_allclose(np.asarray(xr), np.asarray(x), rtol=1e-4) + + def test_softplus_roundtrip_with_units(self): + # T4: dividing out the unit yields a dimensionless Quantity; inverse must + # strip it so bare jnp.expm1/log do not choke on a Quantity. + t = SoftplusT(0.0 * u.mV) + x = jnp.array([-3.0, 0.0, 5.0, 25.0]) + y = t.forward(x) + self.assertEqual(u.get_unit(y), u.mV) + xr = t.inverse(y) + np.testing.assert_allclose(np.asarray(xr), np.asarray(x), rtol=1e-4) + + def test_sigmoid_log_abs_det_jacobian_with_units(self): + # T4: log_abs_det_jacobian crashed when bounds carried units because + # jnp.log was applied to a Quantity width. It must now run and be finite. + t = SigmoidT(0.0 * u.mV, 10.0 * u.mV) + x = jnp.array([-2.0, 0.0, 3.0]) + ladj = t.log_abs_det_jacobian(x, t.forward(x)) + self.assertTrue(np.all(np.isfinite(np.asarray(ladj)))) + # large |x| must not produce -inf/nan (stable log_sigmoid form). + ladj_big = t.log_abs_det_jacobian(jnp.array([100.0, -100.0]), None) + self.assertTrue(np.all(np.isfinite(np.asarray(ladj_big)))) + + def test_affine_log_abs_det_jacobian_batch_shape(self): + # T5: batched input (B, n) must yield a (B,) log-det, not a single scalar. + t = AffineT(2.0, 1.0) + x = jnp.ones((4, 3)) + ladj = t.log_abs_det_jacobian(x, t.forward(x)) + self.assertEqual(np.shape(ladj), (4,)) + np.testing.assert_allclose(np.asarray(ladj), 3 * np.log(2.0), rtol=1e-6) + + def test_affine_log_abs_det_jacobian_array_scale(self): + # T5: per-dimension scale must contract over the event axis as sum(log|a_i|). + a = jnp.array([2.0, 3.0, 4.0]) + t = AffineT(a, 0.0) + x = jnp.ones((5, 3)) + ladj = t.log_abs_det_jacobian(x, t.forward(x)) + self.assertEqual(np.shape(ladj), (5,)) + np.testing.assert_allclose( + np.asarray(ladj), np.sum(np.log(np.abs(np.asarray(a)))), rtol=1e-6 + ) + + def test_affine_log_abs_det_jacobian_with_units(self): + # T4: unit-carrying scale must not crash jnp.log. + t = AffineT(2.0 * u.mV, 0.0 * u.mV) + x = jnp.array([1.0, 2.0]) + ladj = t.log_abs_det_jacobian(x, None) + self.assertTrue(np.all(np.isfinite(np.asarray(ladj)))) + np.testing.assert_allclose(np.asarray(ladj), 2 * np.log(2.0), rtol=1e-6) + + if __name__ == '__main__': unittest.main() diff --git a/brainstate/nn/_utils.py b/brainstate/nn/_utils.py index fc4d6e25..397d38d7 100644 --- a/brainstate/nn/_utils.py +++ b/brainstate/nn/_utils.py @@ -149,6 +149,15 @@ def clip_grad_norm( where :math:`\\|g\\|_p` is the p-norm of the concatenated gradient vector. + Notes + ----- + Gradients are flattened and concatenated with plain array ops, so this function + assumes **unitless** (dimensionless) gradients. ``brainunit.Quantity`` gradients + carrying physical units will have their units stripped in the returned norm, and + gradients with *different* units across leaves cannot be combined meaningfully. + Strip units (e.g. via ``u.get_mantissa``) before clipping if your gradients carry + units. + Examples -------- .. code-block:: python diff --git a/brainstate/nn/init.py b/brainstate/nn/init.py index 23acc92e..32667ee6 100644 --- a/brainstate/nn/init.py +++ b/brainstate/nn/init.py @@ -347,10 +347,11 @@ class TruncatedNormal(Initializer): The standard deviation of the normal distribution before truncating. lower : float, ndarray A float or array of floats representing the lower bound for - truncation. Must be broadcast-compatible with ``upper``. + truncation. Must be broadcast-compatible with ``upper``. Default is ``-2.0`` + (in units of standard deviations before scaling). upper : float, ndarray A float or array of floats representing the upper bound for - truncation. Must be broadcast-compatible with ``lower``. + truncation. Must be broadcast-compatible with ``lower``. Default is ``2.0``. """ __module__ = 'brainstate.nn' @@ -360,8 +361,8 @@ def __init__( loc: ArrayLike = 0., scale: ArrayLike = 1., unit: u.Unit = u.UNITLESS, - lower: ArrayLike = None, - upper: ArrayLike = None, + lower: ArrayLike = -2.0, + upper: ArrayLike = 2.0, seed: SeedOrKey = None, ): super().__init__() diff --git a/brainstate/nn/init_test.py b/brainstate/nn/init_test.py index 34fd9475..ffa90d4b 100644 --- a/brainstate/nn/init_test.py +++ b/brainstate/nn/init_test.py @@ -564,12 +564,14 @@ def test_bounds_respected(self): self.assertGreaterEqual(arr.min(), -1.001) self.assertLessEqual(arr.max(), 1.001) - @pytest.mark.skip(reason="BUG: TruncatedNormal() with default lower/upper=None " - "crashes in random.truncated_normal (None - array)") def test_default_bounds(self): - """``TruncatedNormal`` should work with default (None) bounds.""" + """L3: ``TruncatedNormal`` should work with its default bounds (no crash).""" out = brainstate.nn.init.TruncatedNormal()((4,)) self.assertEqual(out.shape, (4,)) + # Default bounds are +/- 2 standard deviations. + arr = np.asarray(brainstate.nn.init.TruncatedNormal(seed=0)((500,))) + self.assertGreaterEqual(arr.min(), -2.001) + self.assertLessEqual(arr.max(), 2.001) class TestGammaExponential(unittest.TestCase):