Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions brainstate/nn/_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def soft_shrink(x: ArrayLike, lambd: float = 0.5) -> Union[jax.Array, u.Quantity
u.math.where(
x < -lambd,
x + lambd,
u.Quantity(0., unit=u.get_unit(lambd))
u.Quantity(0., unit=u.get_unit(x))
)
)

Expand Down Expand Up @@ -279,8 +279,12 @@ def rrelu(x: ArrayLike, lower: float = 0.125, upper: float = 0.3333333333333333)
.. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
in Convolutional Network." arXiv:1505.00853
"""
a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
# Derive a float dtype robustly: ``x`` may be a python scalar or an integer
# array (neither of which has a usable float ``.dtype`` for sampling).
mantissa = u.get_mantissa(x)
sample_dtype = jax.numpy.result_type(mantissa, float)
a = random.uniform(lower, upper, size=u.math.shape(x), dtype=sample_dtype)
return u.math.where(mantissa >= 0., x, a * x)


def hard_shrink(x: ArrayLike, lambd: float = 0.5) -> Union[jax.Array, u.Quantity]:
Expand Down
21 changes: 21 additions & 0 deletions brainstate/nn/_activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,5 +350,26 @@ def f(hx, _):
fwd() # doesn't crash


class TestActivationAuditRegressions(parameterized.TestCase):
"""Regression tests for bugs found in the nn-module audit."""

def test_rrelu_accepts_python_scalar(self):
"""A2: rrelu must not crash on a python float (which has no .dtype)."""
with brainstate.random.seed_context(0):
out = brainstate.nn.rrelu(1.0)
self.assertTrue(np.isfinite(np.asarray(out)).all())
out_neg = brainstate.nn.rrelu(-1.0)
self.assertTrue(np.isfinite(np.asarray(out_neg)).all())

def test_rrelu_accepts_integer_array(self):
"""A2: rrelu must not crash on an integer-typed array."""
with brainstate.random.seed_context(0):
x = jnp.array([-2, -1, 0, 1, 2], dtype=jnp.int32)
out = np.asarray(brainstate.nn.rrelu(x))
self.assertEqual(out.shape, (5,))
# Non-negative entries pass through unchanged.
np.testing.assert_allclose(out[2:], np.array([0., 1., 2.]), rtol=1e-6)


if __name__ == '__main__':
absltest.main()
75 changes: 47 additions & 28 deletions brainstate/nn/_collective_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@

import jax

from brainstate._state import catch_new_states
from brainstate._utils import set_module_as
from brainstate.graph import nodes
from brainstate.transform import vmap, vmap_new_states
from brainstate.transform import vmap_new_states
from brainstate.typing import Filter
from ._module import Module

Expand Down Expand Up @@ -274,24 +273,25 @@ def vmap_call_all_fns(
if not isinstance(kwargs, Mapping):
raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

@vmap(axis_size=axis_size)
def vmapped_fn():
with catch_new_states(state_tag) as inner_catcher:
call_all_fns(
target,
fn_name=fn_name,
args=args,
kwargs=kwargs,
node_to_exclude=node_to_exclude,
fn_if_not_exist=fn_if_not_exist
)
return inner_catcher.get_state_values()

with catch_new_states(state_tag) as outer_catcher:
values = vmapped_fn()
states = outer_catcher.get_states()
for state, value in zip(states, values):
state.value = value
# Delegate to ``vmap_new_states``, the same transform used by
# ``vmap_init_all_states``. The previous hand-rolled implementation paired an
# inner ``catch_new_states`` (returning per-lane values) with an outer one
# (returning the State objects) and wrote the values back manually. Because the
# States were created *inside* the vmap trace, the write-back could commit a
# ``BatchTracer`` into ``state.value`` and the batched leading axis was never
# materialized, raising ``UnexpectedTracerError`` on later use. ``vmap_new_states``
# handles new-state batching correctly in a single, well-tested pass.
def call_fn():
call_all_fns(
target,
fn_name=fn_name,
args=args,
kwargs=kwargs,
node_to_exclude=node_to_exclude,
fn_if_not_exist=fn_if_not_exist,
)

vmap_new_states(call_fn, state_tag=state_tag, axis_size=axis_size)()
return target


Expand Down Expand Up @@ -631,14 +631,33 @@ def assign_state_values(

# Get current module states
variables = target.states()
keys1 = set(all_states.keys())
keys2 = set(variables.keys())

# Update matching states
for key in keys2.intersection(keys1):
variables[key].value = jax.numpy.asarray(all_states[key])
# Normalize keys so both tuple paths (``('layer1', 'weight')`` — the form
# ``states()`` returns) and the documented dotted-string paths
# (``'layer1.weight'``) compare equal (M2). We keep the *original* key
# objects for the returned ``unexpected``/``missing`` lists so callers that
# rely on tuple paths keep working.
def _norm_key(k):
return '.'.join(str(p) for p in k) if isinstance(k, tuple) else str(k)

var_by_norm = {_norm_key(k): k for k in variables.keys()}
incoming_by_norm = {_norm_key(k): k for k in all_states.keys()}

# Update matching states. ``jax.tree.map`` preserves pytree-structured values
# (e.g. dict-valued states) and unit-carrying ``Quantity`` values, both of
# which ``jax.numpy.asarray`` would reject (M1).
for norm in set(var_by_norm).intersection(incoming_by_norm):
variables[var_by_norm[norm]].value = jax.tree.map(
jax.numpy.asarray, all_states[incoming_by_norm[norm]]
)

# Return mismatched keys
unexpected_keys = sorted(keys1 - keys2)
missing_keys = sorted(keys2 - keys1)
# Return mismatched keys in their original (caller-supplied / module) form.
unexpected_keys = sorted(
(incoming_by_norm[n] for n in set(incoming_by_norm) - set(var_by_norm)),
key=_norm_key,
)
missing_keys = sorted(
(var_by_norm[n] for n in set(var_by_norm) - set(incoming_by_norm)),
key=_norm_key,
)
return unexpected_keys, missing_keys
59 changes: 51 additions & 8 deletions brainstate/nn/_collective_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,10 +827,6 @@ def test_call_completes_and_runs_body(self):
self.assertIs(returned, module)
self.assertTrue(hasattr(module, 'w'))

@pytest.mark.skip(reason="BUG: vmap_call_all_fns leaks a JAX BatchTracer into "
"newly created state values; the batched value is not "
"committed (shape stays per-lane and later use raises "
"UnexpectedTracerError). See report.")
def test_batched_init_creates_leading_axis(self):
"""Vmapped init_state should create a committed leading batch axis."""
module = EnsembleModule()
Expand All @@ -840,8 +836,6 @@ def test_batched_init_creates_leading_axis(self):
)
self.assertEqual(module.w.value.shape, (_testing.SMALL_BATCH, 3))

@pytest.mark.skip(reason="BUG: vmap_call_all_fns leaks a JAX BatchTracer into "
"newly created state values. See report.")
def test_batched_init_distinct_per_lane(self):
"""Each batch lane should receive an independent random initialization."""
module = EnsembleModule()
Expand All @@ -851,8 +845,6 @@ def test_batched_init_distinct_per_lane(self):
)
self.assertFalse(bool(jnp.allclose(module.w.value[0], module.w.value[1])))

@pytest.mark.skip(reason="BUG: vmap_call_all_fns leaks a JAX BatchTracer into "
"newly created state values. See report.")
def test_positional_single_arg_wrapped(self):
"""A single non-tuple positional argument is wrapped in a tuple."""

Expand Down Expand Up @@ -1011,3 +1003,54 @@ def test_multiple_dicts_merge_last_wins(self):
{target_key: jnp.ones(3) * 7.0},
)
_testing.assert_allclose(net.states()[target_key].value, jnp.ones(3) * 7.0)

# --- Audit regressions (M1: pytree/unit values; M2: dotted-string keys) ---

def test_assign_dict_valued_state(self):
"""M1: a state whose value is a dict must round-trip without crashing."""
import brainunit as u

class DictState(brainstate.nn.Module):
def init_state(self):
self.d = brainstate.ParamState({'a': jnp.ones(2), 'b': jnp.zeros(3)})

net = DictState()
brainstate.nn.init_all_states(net)
key = ('d',)
unexpected, missing = brainstate.nn.assign_state_values(
net, {key: {'a': jnp.ones(2) * 5.0, 'b': jnp.ones(3) * 2.0}}
)
self.assertEqual(unexpected, [])
self.assertEqual(missing, [])
_testing.assert_allclose(net.states()[key].value['a'], jnp.ones(2) * 5.0)
_testing.assert_allclose(net.states()[key].value['b'], jnp.ones(3) * 2.0)

def test_assign_quantity_valued_state(self):
"""M1: a state whose value carries physical units must keep its unit."""
import brainunit as u

class QtyState(brainstate.nn.Module):
def init_state(self):
self.v = brainstate.ParamState(jnp.ones(3) * u.mV)

net = QtyState()
brainstate.nn.init_all_states(net)
key = ('v',)
unexpected, missing = brainstate.nn.assign_state_values(
net, {key: jnp.ones(3) * 7.0 * u.mV}
)
self.assertEqual((unexpected, missing), ([], []))
restored = net.states()[key].value
self.assertEqual(u.get_unit(restored), u.mV)
_testing.assert_allclose(u.get_mantissa(restored), jnp.ones(3) * 7.0)

def test_assign_accepts_dotted_string_keys(self):
"""M2: dotted-string keys (as documented) must match tuple state paths."""
net = self._make_net()
unexpected, missing = brainstate.nn.assign_state_values(
net, {'w': jnp.ones(3) * 3.0, 'b': jnp.ones(2) * 4.0}
)
self.assertEqual(unexpected, [])
self.assertEqual(missing, [])
_testing.assert_allclose(net.states()[('w',)].value, jnp.ones(3) * 3.0)
_testing.assert_allclose(net.states()[('b',)].value, jnp.ones(2) * 4.0)
26 changes: 16 additions & 10 deletions brainstate/nn/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ def _filter_states(
filtered_states = None
elif isinstance(filters, dict):
in_states_filter = defaultdict(list)
for filter_, axis in filters:
# ``filters`` maps ``{filter: axis}`` (see Parameters). Iterate ``.items()``
# so the filter and its axis are unpacked correctly; iterating the dict
# directly yields only keys and breaks the documented form.
for filter_, axis in filters.items():
assert isinstance(axis, int), 'The value of in_states must be the map axis, which should be an integer.'
in_states_filter[axis].append(filter_)
filtered_states = module.states(*in_states_filter.values())
Expand Down Expand Up @@ -564,24 +567,27 @@ def update(self, *args, **kwargs):
raise ValueError(
'Map.update called before init_all_states. Please call init_all_states first.'
)
map_kwargs = dict(
in_axes=self.in_axes,
out_axes=self.out_axes,
axis_name=self.axis_name,
state_in_axes=self._call_state_axes,
state_out_axes=self._call_state_axes,
)
if self.behavior == 'vmap':
map_fn = vmap2
# ``spmd_axis_name`` is a vmap-only concept (nested vmap-over-pmap SPMD).
map_kwargs['spmd_axis_name'] = self.spmd_axis_name
elif self.behavior == 'pmap':
map_fn = pmap2
# ``pmap2`` has no ``spmd_axis_name`` parameter; forwarding it raises a
# TypeError, so it is intentionally omitted here.
else:
raise ValueError(
'Invalid behavior specified. Must be "vmap" or "pmap".'
)

return map_fn(
self.module,
in_axes=self.in_axes,
out_axes=self.out_axes,
axis_name=self.axis_name,
spmd_axis_name=self.spmd_axis_name,
state_in_axes=self._call_state_axes,
state_out_axes=self._call_state_axes,
)(*args, **kwargs)
return map_fn(self.module, **map_kwargs)(*args, **kwargs)

def map(
self,
Expand Down
41 changes: 21 additions & 20 deletions brainstate/nn/_common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ def test_update_with_explicit_context(self):
class TestFilterStatesDict(unittest.TestCase):
"""Exercise the dictionary branch of ``_filter_states``."""

def test_filter_states_dict_tuple_key(self):
"""Map a (filter, axis) tuple key to its selected states by axis."""
def test_filter_states_dict_multiple_axes(self):
"""M3: documented ``{filter: axis}`` mapping selects states per axis."""

class M(Module):
"""Module carrying one ParamState and one ShortTermState."""
Expand All @@ -192,17 +192,18 @@ def __init__(self):
self.s = brainstate.ShortTermState(jnp.zeros(3))

module = M()
# The implementation iterates over the dict's keys and unpacks each into
# (filter_, axis); supply a 2-tuple key so the unpack succeeds.
result = _filter_states(module, {(OfType(brainstate.ParamState), 0): 'ignored'})
# Two filters mapping to two distinct axes (the documented form).
result = _filter_states(
module,
{OfType(brainstate.ParamState): 0, OfType(brainstate.ShortTermState): 1},
)
self.assertIn(0, result)
self.assertIn(1, result)
self.assertEqual(len(result[0]), 1)
self.assertEqual(len(result[1]), 1)

@pytest.mark.skip(reason="BUG: _filter_states dict branch unpacks dict keys "
"instead of items; documented {filter: axis} form "
"raises TypeError. See report.")
def test_filter_states_dict_documented_form(self):
"""Documented ``{filter: axis}`` mapping should select states by axis."""
"""M3: documented ``{filter: axis}`` mapping should select states by axis."""

class M(Module):
"""Module carrying one ParamState."""
Expand All @@ -215,6 +216,7 @@ def __init__(self):
module = M()
result = _filter_states(module, {OfType(brainstate.ParamState): 0})
self.assertIn(0, result)
self.assertEqual(len(result[0]), 1)


class TestToPredicate(unittest.TestCase):
Expand Down Expand Up @@ -439,26 +441,25 @@ def test_pmap_init_and_map_call(self):
raise
self.assertEqual(out.shape, (1, 2))

def test_pmap_update_is_broken(self):
"""pmap-mode ``update`` raises TypeError from an invalid pmap2 kwarg.
def test_pmap_update_runs(self):
"""M5: pmap-mode ``update`` runs (``spmd_axis_name`` is not forwarded to pmap2).

``Map.update`` always forwards ``spmd_axis_name`` to the map function,
but ``pmap2`` does not accept that keyword. This documents the bug while
keeping the suite green; see report.
Previously ``Map.update`` always forwarded ``spmd_axis_name`` to the map
function, but ``pmap2`` does not accept that keyword, so pmap-mode
``update`` raised ``TypeError``. It must now run and produce the batched
output shape.
"""
m = Map(_ParamModule(), init_map_size=1, behavior='pmap')
try:
# ``init_all_states`` already drives the pmap path; jax < 0.10 rejects
# ordered effects there, before the TypeError this test documents.
# jax < 0.10 rejects ordered effects inside pmap (raised as early as
# init_all_states), so the whole pmap interaction is guarded.
m.init_all_states(din=3, dout=2)
m.update(jnp.ones((1, 3)))
except TypeError:
return # expected: pmap2 rejects the forwarded spmd_axis_name kwarg
out = m.update(jnp.ones((1, 3)))
except ValueError as e:
if 'Ordered effects' in str(e):
self.skipTest(f'pmap on this JAX version rejects ordered effects: {e}')
raise
self.fail('expected a TypeError from the invalid pmap2 kwarg, but no error was raised')
self.assertEqual(out.shape, (1, 2))


class TestMapCaller(unittest.TestCase):
Expand Down
Loading