Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions brainstate/transform/_mapping1.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def _vmap_transform(
mapping_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name),
unexpected_out_state_mapping='raise',
name='vmap',
# #5: speak the legacy vmap vocabulary (in_states/out_states, no policy
# knob) in the undeclared-write error instead of engine internals.
out_decl_name='out_states',
out_decl_extra='',
)

@functools.wraps(f)
Expand Down
18 changes: 18 additions & 0 deletions brainstate/transform/_mapping1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,24 @@ def fn(x):
with self.assertRaises(BatchAxisError):
fn(xs)

def test_undeclared_write_error_uses_vmap_vocabulary(self):
    # #5: callers of the legacy vmap API think in terms of
    # in_states/out_states, so the undeclared-write error has to use that
    # wording and must not leak the engine-side names
    # (state_out_axes / unexpected_out_state_mapping).
    accumulator = bst.ShortTermState(jnp.zeros(3))

    @vmap(in_axes=0)
    def add_into_state(x):
        # Batched write to a state that was never listed in out_states.
        accumulator.value = accumulator.value + x
        return x

    with self.assertRaises(BatchAxisError) as caught:
        add_into_state(jnp.arange(3.0))

    message = str(caught.exception)
    self.assertIn('out_states', message)
    for engine_term in ('state_out_axes', 'unexpected_out_state_mapping'):
        self.assertNotIn(engine_term, message)


class TestVmapNewStates(unittest.TestCase):
"""Test vmap_new_states functionality."""
Expand Down
74 changes: 60 additions & 14 deletions brainstate/transform/_mapping2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union, TypeVar

import jax
import jax.numpy as jnp

from brainstate._compatible_import import Device
from brainstate._state import State, catch_new_states, StateTraceStack, TRACE_CONTEXT
Expand Down Expand Up @@ -104,12 +105,28 @@ class StatefulMapping:
same selector semantics as ``state_in_axes``.
unexpected_out_state_mapping : {'auto', 'raise', 'warn', 'ignore'}, default 'auto'
Policy for states written with a batched value but not covered by
``state_out_axes``. ``'auto'`` scatters them at their detected axis,
``'raise'`` raises a :class:`~brainstate._error.BatchAxisError`,
``'warn'`` scatters them with a warning, and ``'ignore'`` scatters them
``state_out_axes``. ``'auto'`` scatters them at their detected axis;
``'raise'`` raises a :class:`~brainstate._error.BatchAxisError`;
``'warn'`` scatters them with a warning; and ``'ignore'`` scatters them
silently.

Under ``'auto'``, an undeclared state that is *read and written*
(read-modify-write) is handled specially: on each call, if its current
leading size along the detected axis already equals the mapped size, it is
fed **per lane** (a per-lane RMW buffer) instead of being broadcast and
scattered. This decision is re-made every call from the live value, so a
cached (warm) call behaves identically to a fresh (cold) one. Because the
choice rests on a leading-size match, a state whose leading dimension
*coincidentally* equals the mapped size is treated as per-lane even if you
meant it as a shared, broadcast value -- declare it explicitly via
``state_in_axes``/``state_out_axes`` to remove the ambiguity. When such an
RMW state's leading size does *not* match the mapped size it is scattered
(gaining a new leading axis) and a one-time :class:`UserWarning` is
emitted.
static_argnums : int or iterable of int, default ()
Positional arguments treated as compile-time constants for caching.
Positional arguments treated as compile-time constants. Each such argument
is closed over (as in :func:`jax.jit`) and is neither traced nor mapped, so
its ``in_axes`` entry, if any, is ignored.
static_argnames : str or iterable of str, default ()
Keyword arguments treated as compile-time constants for caching.
axis_env : sequence, optional
Expand Down Expand Up @@ -312,7 +329,13 @@ def vmap2(
unexpected_out_state_mapping : {'auto', 'raise', 'warn', 'ignore'}, default 'auto'
Policy for states written with a batched value but not declared in
``state_out_axes``. The default ``'auto'`` infers the output axis from
the detected batch dimension.
the detected batch dimension. For an undeclared *read-modify-write* state,
``'auto'`` additionally feeds it per lane when its current leading size
matches the mapped size (re-checked every call, so warm and cold calls
agree), and warns once when it does not (the state is scattered, gaining a
new axis). See :class:`StatefulMapping` for the full description and how to
disambiguate a coincidental size match with
``state_in_axes``/``state_out_axes``.

Returns
-------
Expand Down Expand Up @@ -579,6 +602,12 @@ def _validate_leading_lengths(xs) -> int:
leaves = jax.tree.leaves(xs)
if not leaves:
raise ValueError("map requires at least one array input.")
for leaf in leaves:
if getattr(leaf, 'ndim', 1) == 0:
raise ValueError(
"map requires array inputs with a leading axis to map over; "
"got a 0-d (scalar) input."
)
length = leaves[0].shape[0]
for leaf in leaves:
if leaf.shape[0] != length:
Expand Down Expand Up @@ -811,7 +840,9 @@ def probe_hook(state):
# --- main pass: create batched states under the mapping primitive ----- #
state_box: Dict[Any, list] = {}

def init_fn(rng_keys):
def init_fn(rng_keys, _dummy=None):
# ``_dummy`` is an ignored mapped placeholder used only when there are no
# random states, so ``jax.pmap`` still has a mapped argument (#2).
for rng, key in zip(rng_states, rng_keys):
rng.restore_value(key)
with catch_new_states() as catcher:
Expand Down Expand Up @@ -844,14 +875,29 @@ def init_fn(rng_keys):

try:
with catch_new_states(state_tag):
mapped = primitive(
init_fn,
in_axes=(0 if len(rng_states) else None,),
out_axes=tuple(axes_order),
axis_size=axis_size,
axis_name=axis_name,
)
tuple_vals = mapped(rng_keys)
if len(rng_states):
mapped = primitive(
init_fn,
in_axes=(0,),
out_axes=tuple(axes_order),
axis_size=axis_size,
axis_name=axis_name,
)
tuple_vals = mapped(rng_keys)
else:
# No random states: ``rng_keys`` is empty, so the only argument has
# no mapped axis. ``jax.vmap`` tolerates that (``axis_size`` supplies
# the extent), but ``jax.pmap`` requires at least one mapped
# argument. Feed a dummy iota of length ``axis_size`` that
# ``init_fn`` ignores, uniformly for both primitives (#2).
mapped = primitive(
init_fn,
in_axes=(None, 0),
out_axes=tuple(axes_order),
axis_size=axis_size,
axis_name=axis_name,
)
tuple_vals = mapped(rng_keys, jnp.arange(axis_size))
finally:
# restore the global RNG once -- also on failure, so a crashed mapped
# pass cannot leave key tracers in the random states
Expand Down
147 changes: 147 additions & 0 deletions brainstate/transform/_mapping2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ def init_state(self, k):
self.assertEqual(m.w.value.shape, (self.n, 4))
self.assertFalse(jnp.allclose(m.w.value[0], m.w.value[1]))

def test_pmap2_new_states_no_rng(self):
    # #2: state initialisation that never touches a RandomState still has to
    # run under pmap. jax.pmap insists on at least one mapped argument, so
    # the engine supplies a dummy iota. This mirrors the example in the
    # pmap2_new_states docstring.
    class ParallelCounter(brainstate.nn.Module):
        def init_state(self):
            self.count = brainstate.ShortTermState(jnp.zeros(()))

    counter = ParallelCounter()
    pmap2_new_states(counter, init_kwargs={}, axis_size=self.n)

    expected_shape = (self.n,)
    self.assertEqual(counter.count.value.shape, expected_shape)


class TestIntegration(unittest.TestCase):
"""General integration scenarios."""
Expand Down Expand Up @@ -416,6 +428,81 @@ def test_static_argnums_none(self):
self.assertEqual(sm.static_argnums, ())


class TestStaticArgnumsExcludedFromMapping(unittest.TestCase):
    """Regression tests for issue #8.

    A positional argument listed in ``static_argnums`` has to be closed over
    (jit parity) instead of being mapped. Mapping a non-array constant used to
    fail with ``'bool' object has no attribute 'ndim'``.
    """

    def test_bool_static_positional_does_not_crash(self):
        # Documented reproduction: a plain Python bool steers control flow.
        def scale(x, flag):
            factor = 2. if flag else 1.
            return x * factor

        mapping = StatefulMapping(scale, in_axes=0, static_argnums=(1,))
        result = mapping(jnp.arange(3.), True)
        self.assertTrue(jnp.allclose(result, jnp.arange(3.) * 2.))

    def test_static_value_controls_branch(self):
        # The static value is closed over on every call, so True and False
        # pick different Python branches (distinct cache keys and plans).
        def scale(x, flag):
            factor = 2. if flag else 1.
            return x * factor

        mapping = StatefulMapping(scale, in_axes=0, static_argnums=(1,))
        doubled = mapping(jnp.arange(3.), True)
        unchanged = mapping(jnp.arange(3.), False)
        for observed, factor in ((doubled, 2.), (unchanged, 1.)):
            self.assertTrue(jnp.allclose(observed, jnp.arange(3.) * factor))

    def test_matches_jax_jit_over_lanes(self):
        # The static argument acts as a compile-time constant, so the mapped
        # result equals the per-lane scalar computation.
        def repeated_add(x, n):
            total = x
            # range() needs a concrete Python int, hence n must be static.
            for _ in range(n):
                total = total + x
            return total

        mapping = StatefulMapping(repeated_add, in_axes=0, static_argnums=(1,))
        result = mapping(jnp.arange(4.), 3)
        self.assertTrue(jnp.allclose(result, jnp.arange(4.) * 4.))

    def test_static_positional_with_state_write(self):
        # Closing over the static argument must leave state side effects
        # untouched.
        tally = brainstate.ShortTermState(jnp.zeros(3))

        def accumulate(x, scale):
            tally.value = tally.value + x * scale
            return x

        mapping = StatefulMapping(
            accumulate,
            in_axes=0,
            state_in_axes={0: filter.OfType(brainstate.ShortTermState)},
            state_out_axes={0: filter.OfType(brainstate.ShortTermState)},
            static_argnums=(1,),
        )
        result = mapping(jnp.ones(3), 5)
        self.assertTrue(jnp.allclose(result, jnp.ones(3)))
        self.assertTrue(jnp.allclose(tally.value, jnp.ones(3) * 5.))

    def test_negative_static_argnum(self):
        # Negative indices are resolved against the number of call arguments.
        def scale(x, flag):
            factor = 2. if flag else 1.
            return x * factor

        mapping = StatefulMapping(scale, in_axes=0, static_argnums=(-1,))
        result = mapping(jnp.arange(3.), True)
        self.assertTrue(jnp.allclose(result, jnp.arange(3.) * 2.))

    def test_tuple_in_axes_drops_static_entry(self):
        # With a per-argument in_axes tuple, the entry at the static position
        # is dropped so the remaining dynamic arguments keep their own axes.
        def add_then_scale(x, y, flag):
            factor = 2. if flag else 1.
            return (x + y) * factor

        mapping = StatefulMapping(
            add_then_scale, in_axes=(0, 0, None), static_argnums=(2,)
        )
        result = mapping(jnp.arange(3.), jnp.ones(3), True)
        self.assertTrue(jnp.allclose(result, (jnp.arange(3.) + 1.) * 2.))

    def test_out_of_range_static_argnum_raises_value_error(self):
        # Expect a clear ValueError (jit parity) rather than an opaque
        # IndexError.
        def passthrough(x, flag):
            return x

        mapping = StatefulMapping(passthrough, in_axes=0, static_argnums=(5,))
        with self.assertRaises(ValueError):
            mapping(jnp.arange(3.), True)


class TestMap(unittest.TestCase):
def test_map_matches_vectorized(self):
xs = jnp.arange(6.0).reshape(6, 1)
Expand Down Expand Up @@ -525,6 +612,53 @@ def update(delta):
self.assertTrue(jnp.all(updated >= 1.0))


class TestAxisSizeValidation(unittest.TestCase):
    """Checks that ``axis_size`` is validated against the inferred batch size (#4)."""

    def _assert_reports_conflict(self, exc):
        # The message has to name the conflict and mention both sizes
        # (the requested axis_size 5 and the inferred size 3).
        message = str(exc)
        for fragment in ('conflicts with the mapped', '5', '3'):
            self.assertIn(fragment, message)

    def test_axis_size_conflict_raises_no_rng(self):
        @vmap2(in_axes=0, axis_size=5)
        def double(x):
            return x * 2

        with self.assertRaises(ValueError) as caught:
            double(jnp.arange(3.0))
        self._assert_reports_conflict(caught.exception)

    def test_axis_size_conflict_raises_with_rng(self):
        # With an RNG draw, keys used to be split to the *inferred* size
        # instead of axis_size; validation has to flag the conflict first.
        @vmap2(in_axes=0, axis_size=5)
        def add_noise(x):
            return x + brainstate.random.randn()

        with self.assertRaises(ValueError) as caught:
            add_noise(jnp.zeros(3))
        self._assert_reports_conflict(caught.exception)

    def test_axis_size_matching_ok(self):
        @vmap2(in_axes=0, axis_size=3)
        def double(x):
            return x * 2

        result = double(jnp.arange(3.0))
        self.assertEqual(result.shape, (3,))

    def test_axis_size_alone_with_broadcast_input_ok(self):
        # Every input is broadcast, so no size is inferred and axis_size is
        # taken as given.
        @vmap2(in_axes=None, axis_size=4)
        def add_noise(x):
            return x + brainstate.random.randn()

        result = add_noise(jnp.zeros(()))
        self.assertEqual(result.shape, (4,))


class TestMapValidation(unittest.TestCase):
"""Tests for map() input-validation branches."""

Expand Down Expand Up @@ -689,6 +823,19 @@ def test_matching_lengths_returns_length(self):
length = _validate_leading_lengths((xs, xs))
self.assertEqual(length, 5)

def test_scalar_leaf_raises_value_error(self):
    # #6: a 0-d (scalar) leaf offers no leading axis to map over, so a clear
    # ValueError is expected rather than a cryptic IndexError coming from
    # leaves[0].shape[0].
    scalar_tree = (jnp.array(5.0),)
    with self.assertRaises(ValueError) as caught:
        _validate_leading_lengths(scalar_tree)

    message = str(caught.exception)
    self.assertIn('0-d', message)

def test_map_scalar_input_raises_value_error(self):
    # #6: the same scalar-input failure, reached through the public map()
    # entry point.
    def double(x):
        return x * 2

    with self.assertRaises(ValueError) as caught:
        map(double, jnp.array(5.0))

    message = str(caught.exception)
    self.assertIn('0-d', message)


class TestPmap2Decorator(unittest.TestCase):
"""Tests for pmap2() as a decorator (Missing fn path)."""
Expand Down
Loading