diff --git a/brainstate/transform/_mapping1.py b/brainstate/transform/_mapping1.py index 12825b8..d0fa398 100644 --- a/brainstate/transform/_mapping1.py +++ b/brainstate/transform/_mapping1.py @@ -114,6 +114,10 @@ def _vmap_transform( mapping_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name), unexpected_out_state_mapping='raise', name='vmap', + # #5: speak the legacy vmap vocabulary (in_states/out_states, no policy + # knob) in the undeclared-write error instead of engine internals. + out_decl_name='out_states', + out_decl_extra='', ) @functools.wraps(f) diff --git a/brainstate/transform/_mapping1_test.py b/brainstate/transform/_mapping1_test.py index 38d5d56..5781fb5 100644 --- a/brainstate/transform/_mapping1_test.py +++ b/brainstate/transform/_mapping1_test.py @@ -509,6 +509,24 @@ def fn(x): with self.assertRaises(BatchAxisError): fn(xs) + def test_undeclared_write_error_uses_vmap_vocabulary(self): + # #5: the legacy vmap API uses in_states/out_states, not the engine's + # state_out_axes / unexpected_out_state_mapping. The undeclared-write + # error must speak the caller's vocabulary, not engine internals. + state = bst.ShortTermState(jnp.zeros(3)) + + @vmap(in_axes=0) + def fn(x): + state.value = state.value + x # batched write, not in out_states + return x + + with self.assertRaises(BatchAxisError) as cm: + fn(jnp.arange(3.0)) + msg = str(cm.exception) + self.assertIn('out_states', msg) + self.assertNotIn('state_out_axes', msg) + self.assertNotIn('unexpected_out_state_mapping', msg) + class TestVmapNewStates(unittest.TestCase): """Test vmap_new_states functionality.""" diff --git a/brainstate/transform/_mapping2.py b/brainstate/transform/_mapping2.py index abac99c..0cfd2ee 100644 --- a/brainstate/transform/_mapping2.py +++ b/brainstate/transform/_mapping2.py @@ -22,6 +22,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union, TypeVar import jax +import jax.numpy as jnp from brainstate._compatible_import import Device from brainstate._state import State, catch_new_states, StateTraceStack, TRACE_CONTEXT @@ -104,12 +105,28 @@ class StatefulMapping: same selector semantics as ``state_in_axes``. unexpected_out_state_mapping : {'auto', 'raise', 'warn', 'ignore'}, default 'auto' Policy for states written with a batched value but not covered by - ``state_out_axes``. ``'auto'`` scatters them at their detected axis, - ``'raise'`` raises a :class:`~brainstate._error.BatchAxisError`, - ``'warn'`` scatters them with a warning, and ``'ignore'`` scatters them + ``state_out_axes``. ``'auto'`` scatters them at their detected axis; + ``'raise'`` raises a :class:`~brainstate._error.BatchAxisError`; + ``'warn'`` scatters them with a warning; and ``'ignore'`` scatters them silently. + + Under ``'auto'``, an undeclared state that is *read and written* + (read-modify-write) is handled specially: on each call, if its current + leading size along the detected axis already equals the mapped size it is + fed **per lane** (a per-lane RMW buffer) instead of being broadcast and + scattered. This decision is re-made every call from the live value, so a + cached (warm) call behaves identically to a fresh (cold) one. Because the + choice rests on a leading-size match, a state whose leading dimension + *coincidentally* equals the mapped size is treated as per-lane even if you + meant it as a shared, broadcast value -- declare it explicitly via + ``state_in_axes``/``state_out_axes`` to remove the ambiguity. When such an + RMW state's leading size does *not* match the mapped size it is scattered + (gaining a new leading axis) and a one-time :class:`UserWarning` is + emitted. static_argnums : int or iterable of int, default () - Positional arguments treated as compile-time constants for caching. + Positional arguments treated as compile-time constants. The argument is + closed over (as in :func:`jax.jit`) and is neither traced nor mapped, so + its ``in_axes`` entry, if any, is ignored. static_argnames : str or iterable of str, default () Keyword arguments treated as compile-time constants for caching. axis_env : sequence, optional @@ -312,7 +329,13 @@ def vmap2( unexpected_out_state_mapping : {'auto', 'raise', 'warn', 'ignore'}, default 'auto' Policy for states written with a batched value but not declared in ``state_out_axes``. The default ``'auto'`` infers the output axis from - the detected batch dimension. + the detected batch dimension. For an undeclared *read-modify-write* state, + ``'auto'`` additionally feeds it per lane when its current leading size + matches the mapped size (re-checked every call, so warm and cold calls + agree), and warns once when it does not (the state is scattered, gaining a + new axis). See :class:`StatefulMapping` for the full description and how to + disambiguate a coincidental size match with + ``state_in_axes``/``state_out_axes``. Returns ------- @@ -579,6 +602,12 @@ def _validate_leading_lengths(xs) -> int: leaves = jax.tree.leaves(xs) if not leaves: raise ValueError("map requires at least one array input.") + for leaf in leaves: + if getattr(leaf, 'ndim', 1) == 0: + raise ValueError( + "map requires array inputs with a leading axis to map over; " + "got a 0-d (scalar) input." + ) length = leaves[0].shape[0] for leaf in leaves: if leaf.shape[0] != length: @@ -811,7 +840,9 @@ def probe_hook(state): # --- main pass: create batched states under the mapping primitive ----- # state_box: Dict[Any, list] = {} - def init_fn(rng_keys): + def init_fn(rng_keys, _dummy=None): + # ``_dummy`` is an ignored mapped placeholder used only when there are no + # random states, so ``jax.pmap`` still has a mapped argument (#2). for rng, key in zip(rng_states, rng_keys): rng.restore_value(key) with catch_new_states() as catcher: @@ -844,14 +875,29 @@ def init_fn(rng_keys): try: with catch_new_states(state_tag): - mapped = primitive( - init_fn, - in_axes=(0 if len(rng_states) else None,), - out_axes=tuple(axes_order), - axis_size=axis_size, - axis_name=axis_name, - ) - tuple_vals = mapped(rng_keys) + if len(rng_states): + mapped = primitive( + init_fn, + in_axes=(0,), + out_axes=tuple(axes_order), + axis_size=axis_size, + axis_name=axis_name, + ) + tuple_vals = mapped(rng_keys) + else: + # No random states: ``rng_keys`` is empty, so the only argument has + # no mapped axis. ``jax.vmap`` tolerates that (``axis_size`` supplies + # the extent), but ``jax.pmap`` requires at least one mapped + # argument. Feed a dummy iota of length ``axis_size`` that + # ``init_fn`` ignores, uniformly for both primitives (#2). + mapped = primitive( + init_fn, + in_axes=(None, 0), + out_axes=tuple(axes_order), + axis_size=axis_size, + axis_name=axis_name, + ) + tuple_vals = mapped(rng_keys, jnp.arange(axis_size)) finally: # restore the global RNG once -- also on failure, so a crashed mapped # pass cannot leave key tracers in the random states diff --git a/brainstate/transform/_mapping2_test.py b/brainstate/transform/_mapping2_test.py index a0993bb..3dd6330 100644 --- a/brainstate/transform/_mapping2_test.py +++ b/brainstate/transform/_mapping2_test.py @@ -223,6 +223,18 @@ def init_state(self, k): self.assertEqual(m.w.value.shape, (self.n, 4)) self.assertFalse(jnp.allclose(m.w.value[0], m.w.value[1])) + def test_pmap2_new_states_no_rng(self): + # #2: init that uses no RandomState must still work under pmap (jax.pmap + # requires at least one mapped argument; the engine feeds a dummy iota). + # This is the pmap2_new_states docstring example. + class ParallelCounter(brainstate.nn.Module): + def init_state(self): + self.count = brainstate.ShortTermState(jnp.zeros(())) + + m = ParallelCounter() + pmap2_new_states(m, init_kwargs={}, axis_size=self.n) + self.assertEqual(m.count.value.shape, (self.n,)) + class TestIntegration(unittest.TestCase): """General integration scenarios.""" @@ -416,6 +428,81 @@ def test_static_argnums_none(self): self.assertEqual(sm.static_argnums, ()) +class TestStaticArgnumsExcludedFromMapping(unittest.TestCase): + """Issue #8: a positional ``static_argnums`` arg must be closed over (jit + parity), not mapped -- mapping a non-array constant used to crash with + ``'bool' object has no attribute 'ndim'``.""" + + def test_bool_static_positional_does_not_crash(self): + # The documented repro: flag is a Python bool used in control flow. + sm = StatefulMapping(lambda x, flag: x * (2. if flag else 1.), + in_axes=0, static_argnums=(1,)) + out = sm(jnp.arange(3.), True) + self.assertTrue(jnp.allclose(out, jnp.arange(3.) * 2.)) + + def test_static_value_controls_branch(self): + # The static value is closed over per call: True and False select + # different Python branches (distinct cache keys, distinct plans). + f = lambda x, flag: x * (2. if flag else 1.) + sm = StatefulMapping(f, in_axes=0, static_argnums=(1,)) + out_true = sm(jnp.arange(3.), True) + out_false = sm(jnp.arange(3.), False) + self.assertTrue(jnp.allclose(out_true, jnp.arange(3.) * 2.)) + self.assertTrue(jnp.allclose(out_false, jnp.arange(3.) * 1.)) + + def test_matches_jax_jit_over_lanes(self): + # Static arg behaves like a compile-time constant: result equals the + # per-lane scalar computation. + def f(x, n): + acc = x + for _ in range(n): # n must be a Python int (static) + acc = acc + x + return acc + + sm = StatefulMapping(f, in_axes=0, static_argnums=(1,)) + out = sm(jnp.arange(4.), 3) + self.assertTrue(jnp.allclose(out, jnp.arange(4.) * 4.)) + + def test_static_positional_with_state_write(self): + # Closing over the static arg must not disturb state side effects. + counter = brainstate.ShortTermState(jnp.zeros(3)) + + def f(x, scale): + counter.value = counter.value + x * scale + return x + + sm = StatefulMapping(f, in_axes=0, + state_in_axes={0: filter.OfType(brainstate.ShortTermState)}, + state_out_axes={0: filter.OfType(brainstate.ShortTermState)}, + static_argnums=(1,)) + out = sm(jnp.ones(3), 5) + self.assertTrue(jnp.allclose(out, jnp.ones(3))) + self.assertTrue(jnp.allclose(counter.value, jnp.ones(3) * 5.)) + + def test_negative_static_argnum(self): + # Negative indices are normalized against the call's arg count. + sm = StatefulMapping(lambda x, flag: x * (2. if flag else 1.), + in_axes=0, static_argnums=(-1,)) + out = sm(jnp.arange(3.), True) + self.assertTrue(jnp.allclose(out, jnp.arange(3.) * 2.)) + + def test_tuple_in_axes_drops_static_entry(self): + # When in_axes is a per-argument tuple, the static position's entry is + # dropped so the remaining dynamic args keep their axes. + def f(x, y, flag): + return (x + y) * (2. if flag else 1.) + + sm = StatefulMapping(f, in_axes=(0, 0, None), static_argnums=(2,)) + out = sm(jnp.arange(3.), jnp.ones(3), True) + self.assertTrue(jnp.allclose(out, (jnp.arange(3.) + 1.) * 2.)) + + def test_out_of_range_static_argnum_raises_value_error(self): + # A clear ValueError (jit parity), not an opaque IndexError. + sm = StatefulMapping(lambda x, flag: x, in_axes=0, static_argnums=(5,)) + with self.assertRaises(ValueError): + sm(jnp.arange(3.), True) + + class TestMap(unittest.TestCase): def test_map_matches_vectorized(self): xs = jnp.arange(6.0).reshape(6, 1) @@ -525,6 +612,53 @@ def update(delta): self.assertTrue(jnp.all(updated >= 1.0)) +class TestAxisSizeValidation(unittest.TestCase): + """#4: axis_size is validated against the inferred batch size.""" + + def test_axis_size_conflict_raises_no_rng(self): + @vmap2(in_axes=0, axis_size=5) + def f(x): + return x * 2 + + with self.assertRaises(ValueError) as cm: + f(jnp.arange(3.0)) + msg = str(cm.exception) + self.assertIn('conflicts with the mapped', msg) + self.assertIn('5', msg) + self.assertIn('3', msg) + + def test_axis_size_conflict_raises_with_rng(self): + # With an RNG draw the keys were previously split to the *inferred* size + # rather than axis_size; validation must catch the conflict first. + @vmap2(in_axes=0, axis_size=5) + def f(x): + return x + brainstate.random.randn() + + with self.assertRaises(ValueError) as cm: + f(jnp.zeros(3)) + msg = str(cm.exception) + self.assertIn('conflicts with the mapped', msg) + self.assertIn('5', msg) + self.assertIn('3', msg) + + def test_axis_size_matching_ok(self): + @vmap2(in_axes=0, axis_size=3) + def f(x): + return x * 2 + + out = f(jnp.arange(3.0)) + self.assertEqual(out.shape, (3,)) + + def test_axis_size_alone_with_broadcast_input_ok(self): + # No inferred size (all inputs broadcast) -> axis_size is used as-is. + @vmap2(in_axes=None, axis_size=4) + def f(x): + return x + brainstate.random.randn() + + out = f(jnp.zeros(())) + self.assertEqual(out.shape, (4,)) + + class TestMapValidation(unittest.TestCase): """Tests for map() input-validation branches.""" @@ -689,6 +823,19 @@ def test_matching_lengths_returns_length(self): length = _validate_leading_lengths((xs, xs)) self.assertEqual(length, 5) + def test_scalar_leaf_raises_value_error(self): + # #6: a 0-d (scalar) leaf has no leading axis to map over -> clear + # ValueError, not a cryptic IndexError from leaves[0].shape[0]. + with self.assertRaises(ValueError) as cm: + _validate_leading_lengths((jnp.array(5.0),)) + self.assertIn('0-d', str(cm.exception)) + + def test_map_scalar_input_raises_value_error(self): + # #6: the same through the public map() entry point. + with self.assertRaises(ValueError) as cm: + map(lambda x: x * 2, jnp.array(5.0)) + self.assertIn('0-d', str(cm.exception)) + class TestPmap2Decorator(unittest.TestCase): """Tests for pmap2() as a decorator (Missing fn path).""" diff --git a/brainstate/transform/_mapping_core.py b/brainstate/transform/_mapping_core.py index eaf0464..4d63f91 100644 --- a/brainstate/transform/_mapping_core.py +++ b/brainstate/transform/_mapping_core.py @@ -138,6 +138,12 @@ _rand = None +# Audit #3: states already warned about an undeclared 'auto' read-modify-write +# whose leading dim does not match the batch size (so it is scattered, gaining an +# axis). A WeakSet keeps the warning one-time-per-state without pinning the state +# alive or leaking ids after garbage collection. +_AUTO_RMW_SCATTER_WARNED: "weakref.WeakSet" = weakref.WeakSet() + def _import_rand_state(): global _rand @@ -300,6 +306,16 @@ def _get_batch_size( ) return axis_size + # When axis_size is given explicitly it must agree with the size inferred from + # arguments/states; otherwise the mapping primitive fails late with an opaque + # XLA buffer-size error and the RNG split is sized wrong (audit #4). + if axis_size is not None and axis_size not in set(batch_sizes): + raise ValueError( + f"axis_size={axis_size} conflicts with the mapped axis size(s) " + f"{sorted(set(batch_sizes))} inferred from arguments/states. Either " + f"omit axis_size or make it match the mapped inputs." + ) + # Ensure all batch sizes are consistent. if len(set(batch_sizes)) > 1: raise ValueError(f"Inconsistent batch sizes found: {set(batch_sizes)}") @@ -553,6 +569,74 @@ def _split_kwargs(kwargs, static_argnames): return dyn, static +def _normalize_static_argnums(static_argnums, n_args): + """Resolve ``static_argnums`` to a frozenset of non-negative indices. + + Negative indices are counted from the end (as in :func:`jax.jit`); an index + outside ``range(n_args)`` raises a clear :class:`ValueError` instead of the + opaque ``IndexError`` that would surface later when the argument is sliced. + """ + resolved = set() + for i in static_argnums: + j = i + n_args if i < 0 else i + if not (0 <= j < n_args): + raise ValueError( + f"static_argnums={tuple(static_argnums)} refers to positional " + f"argument {i}, but the function was called with {n_args} " + f"positional argument(s)." + ) + resolved.add(j) + return frozenset(resolved) + + +def _close_static_argnums(f, in_axes, static_argnums, args): + """Bake the positional ``static_argnums`` into ``f`` (``jax.jit`` parity). + + Static positional arguments are compile-time constants: they are neither + traced nor mapped. We drop them from the argument list handed to the mapping + primitive and close over them in a thin wrapper that re-inserts them at their + original positions -- mirroring how ``static_argnames`` keyword arguments are + already handled. This avoids asking the primitive to map a non-array constant + (which fails with ``'' object has no attribute 'ndim'``). + + Parameters + ---------- + f : callable + The user function. + in_axes : int | None | tuple + Positional-argument batch-axis specification. + static_argnums : frozenset of int + Already-normalized (non-negative, in-range) static positional indices. + args : tuple + Full positional arguments for this call. + + Returns + ------- + tuple + ``(f_closed, dyn_args, dyn_in_axes)``. ``dyn_args`` excludes the static + positions; ``dyn_in_axes`` drops the matching entries when ``in_axes`` is + a per-argument tuple (an ``int`` / ``None`` axis is returned unchanged, + since it applies uniformly to whatever positional arguments remain). + """ + if not static_argnums: + return f, args, in_axes + n = len(args) + static_vals = {i: args[i] for i in static_argnums} + dyn_args = tuple(a for i, a in enumerate(args) if i not in static_argnums) + + @functools.wraps(f) + def f_closed(*dyn, **kwargs): + it = iter(dyn) + full = [static_vals[i] if i in static_vals else next(it) for i in range(n)] + return f(*full, **kwargs) + + if isinstance(in_axes, tuple) and len(in_axes) == n: + dyn_in_axes = tuple(ax for i, ax in enumerate(in_axes) if i not in static_argnums) + else: + dyn_in_axes = in_axes + return f_closed, dyn_args, dyn_in_axes + + def _strip_kwargs(dyn_kwargs): """Remove the leading mapped axis from each dynamic kwarg. @@ -621,15 +705,21 @@ class LiveStateMapPlan: oth_out_states : list of State Written states that are broadcast (not batched) on output and therefore restored from a single representative lane. + auto_in_candidate_ids : frozenset of int + Object ids of undeclared ``'auto'`` read-modify-write states whose + per-lane promotion is re-evaluated against the live value on every call + (audit #1). These states already live in :attr:`out_groups`; the ids are + matched there in :func:`_execute_plan`. """ - __slots__ = ('in_groups', 'rng_states', 'out_groups', 'oth_out_states') + __slots__ = ('in_groups', 'rng_states', 'out_groups', 'oth_out_states', 'auto_in_candidate_ids') - def __init__(self, in_groups, rng_states, out_groups, oth_out_states): + def __init__(self, in_groups, rng_states, out_groups, oth_out_states, auto_in_candidate_ids=frozenset()): self.in_groups = in_groups self.rng_states = rng_states self.out_groups = out_groups self.oth_out_states = oth_out_states + self.auto_in_candidate_ids = frozenset(auto_in_candidate_ids) class StateMapPlan: @@ -672,18 +762,23 @@ class StateMapPlan: restored from a single representative lane. """ - __slots__ = ('in_groups', 'rng_states', 'out_groups', 'oth_out_states') + __slots__ = ('in_groups', 'rng_states', 'out_groups', 'oth_out_states', 'auto_in_candidate_ids') - def __init__(self, in_groups, rng_states, out_groups, oth_out_states): + def __init__(self, in_groups, rng_states, out_groups, oth_out_states, auto_in_candidate_ids=frozenset()): self.in_groups = [(axis, [weakref.ref(st) for st in states]) for axis, states in in_groups] self.rng_states = [weakref.ref(st) for st in rng_states] self.out_groups = [(axis, [weakref.ref(st) for st in states]) for axis, states in out_groups] self.oth_out_states = [weakref.ref(st) for st in oth_out_states] + # Plain ids (no weakref): these states are a subset of out_groups, whose + # weakrefs already gate staleness; on materialize the same live objects + # are dereferenced, so their ids still match. + self.auto_in_candidate_ids = frozenset(auto_in_candidate_ids) @classmethod def from_live(cls, live: 'LiveStateMapPlan') -> 'StateMapPlan': """Snapshot a :class:`LiveStateMapPlan` as a weakref-backed plan for caching.""" - return cls(live.in_groups, live.rng_states, live.out_groups, live.oth_out_states) + return cls(live.in_groups, live.rng_states, live.out_groups, live.oth_out_states, + live.auto_in_candidate_ids) def materialize(self) -> Optional['LiveStateMapPlan']: """Resolve all weakrefs into a strong-ref :class:`LiveStateMapPlan`. @@ -722,7 +817,8 @@ def deref(refs): oth_out_states = deref(self.oth_out_states) if oth_out_states is None: return None - return LiveStateMapPlan(in_groups, rng_states, out_groups, oth_out_states) + return LiveStateMapPlan(in_groups, rng_states, out_groups, oth_out_states, + self.auto_in_candidate_ids) def _probe_axis_size(args, in_axes, axis_size, kwargs=None): @@ -740,6 +836,37 @@ def _probe_axis_size(args, in_axes, axis_size, kwargs=None): return 2 +class _ReadTrackingTrace(StateTraceStack): + """Probe trace that records *genuine* reads (value-getter accesses). + + :attr:`StateTraceStack.been_writen` cannot tell a read-modify-write apart + from a pure write: :meth:`StateTraceStack.write_its_value` calls + :meth:`StateTraceStack.read_its_value` internally the first time it sees a + state, so every written state looks "read". This subclass flags the window in + which that internal read happens, so only reads triggered by the value getter + (``state.value``) are counted as genuine. :func:`_build_plan` uses the result + to classify an undeclared ``'auto'`` batched write as a per-lane + read-modify-write (promotable) versus a pure output (scatter only). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.genuine_read_ids: set = set() + self._in_write_value = False + + def read_its_value(self, state) -> None: + if not self._in_write_value: + self.genuine_read_ids.add(id(state)) + super().read_its_value(state) + + def write_its_value(self, state) -> None: + self._in_write_value = True + try: + super().write_its_value(state) + finally: + self._in_write_value = False + + def _probe_states(f, args, kwargs, in_predicates, in_axes, name, axis_name=None, axis_size=None, static_argnames=()): """Enumerate touched states and classify batched inputs. @@ -760,6 +887,9 @@ def _probe_states(f, args, kwargs, in_predicates, in_axes, name, Random states encountered during the call. seen_in_ids : set[int] Object ids of states classified as batched inputs. + genuine_read_ids : set[int] + Object ids of states whose value getter was actually invoked (genuine + reads), used to tell read-modify-write states from pure writes. Notes ----- @@ -792,7 +922,7 @@ def hook(state): return jax.tree.map(lambda x: _deaxed_like(x, axis), state._value, is_leaf=u.math.is_quantity) return state._value - state_trace = StateTraceStack(name=name) + state_trace = _ReadTrackingTrace(name=name) state_trace.set_new_arg(hook) try: if axis_name is None: @@ -809,7 +939,7 @@ def _probe_body(*probe_args): jax.make_jaxpr(_probe_body, axis_env=[(axis_name, probe_size)])(*stripped) finally: state_trace.recovery_original_values() - return state_trace, dict(dim_to_in_states), rng_states, seen_in_ids + return state_trace, dict(dim_to_in_states), rng_states, seen_in_ids, state_trace.genuine_read_ids def _detect_out_dims(f, args, kwargs, in_groups, rng_states, write_states, @@ -881,11 +1011,19 @@ def _build_plan( in_predicates, out_predicates, in_axes, axis_size, axis_name, unexpected_out_state_mapping, name, static_argnames=(), + out_decl_name='state_out_axes', + out_decl_extra=" or set unexpected_out_state_mapping to 'auto', 'warn', or 'ignore'", ): - """Probe + discover + assemble a :class:`StateMapPlan`.""" + """Probe + discover + assemble a :class:`StateMapPlan`. + + ``out_decl_name`` / ``out_decl_extra`` let a caller phrase the undeclared-write + error in its own vocabulary -- the legacy ``vmap`` shim declares states via + ``out_states`` and has no policy knob, so it passes ``out_decl_name='out_states'`` + and an empty ``out_decl_extra`` (audit #5). + """ RandomState = _import_rand_state() - state_trace, dim_to_in_states, rng_states, seen_in_ids = _probe_states( + state_trace, dim_to_in_states, rng_states, seen_in_ids, genuine_read_ids = _probe_states( f, args, kwargs, in_predicates, in_axes, name, axis_name=axis_name, axis_size=axis_size, static_argnames=static_argnames, ) @@ -911,9 +1049,10 @@ def _build_plan( # assemble output groups in deterministic trace order out_axis_groups: Dict[int, List[State]] = defaultdict(list) oth_out_states: List[State] = [] - # B1: undeclared read-modify-write states that are already per-lane along the - # detected axis are promoted to batched inputs (see the 'auto' branch below). - promoted_in_states: Dict[int, List[State]] = defaultdict(list) + # #1: ids of undeclared read-modify-write states whose per-lane promotion is + # decided at execution time against the live value (see the 'auto' branch and + # _execute_plan). Deferring the decision makes warm calls match cold calls. + auto_in_candidate_ids: set = set() def _match_out_axis(st): for axis, pred in out_predicates.items(): @@ -959,8 +1098,7 @@ def _match_out_axis(st): st.raise_error_with_source_info( BatchAxisError( f"State\n {st} \nwas written with a batched value on axis {det} but is " - "not covered by state_out_axes. Declare it in state_out_axes or set " - "unexpected_out_state_mapping to 'auto', 'warn', or 'ignore'." + f"not covered by {out_decl_name}. Declare it in {out_decl_name}{out_decl_extra}." ) ) elif unexpected_out_state_mapping == 'warn': @@ -971,14 +1109,35 @@ def _match_out_axis(st): ) out_axis_groups[det].append(st) elif unexpected_out_state_mapping == 'auto': - # B1: if the state's prior value is already sized like the batch - # along the detected axis, this is a read-modify-write input the - # user did not declare. Treat it as a batched input+output so its - # value stays stable across calls, instead of gaining an extra - # axis every call. Otherwise scatter it (a genuine batched output). - if _leaf_axis_size(st.value, det) == batch_size: - promoted_in_states[det].append(st) + # #1: scatter the write at its detected axis. If the state is also + # genuinely read (read-modify-write), record it as a per-lane + # promotion candidate: _execute_plan re-checks, on every call, + # whether the live value is already batch-sized along this axis and + # if so feeds it per lane instead of broadcasting. Deferring this + # decision (instead of baking it into the cached plan) keeps warm + # calls in lock-step with cold calls, so the value no longer gains a + # new leading axis on every warm call. out_axis_groups[det].append(st) + if id(st) in genuine_read_ids: + auto_in_candidate_ids.add(id(st)) + # #3: when the live leading dim does not match the batch size, + # the state is scattered (gaining an axis), which is rarely the + # intent for a read-modify-write buffer. Surface the engine's + # otherwise-silent choice once per state. + cur = _leaf_axis_size(st.value, det) + if cur != batch_size and st not in _AUTO_RMW_SCATTER_WARNED: + _AUTO_RMW_SCATTER_WARNED.add(st) + warnings.warn( + f"State {st} is written with a batched value on axis {det} " + f"under the 'auto' policy but is not declared in " + f"state_in_axes/state_out_axes, and its current leading size " + f"({cur}) does not match the mapped size ({batch_size}); it is " + f"being scattered, which adds a new leading axis. If it is a " + f"per-lane read-modify-write buffer, pre-shape it to the mapped " + f"size or declare it via state_in_axes/state_out_axes to make " + f"the intent explicit.", + UserWarning, + ) elif unexpected_out_state_mapping == 'ignore': out_axis_groups[det].append(st) else: @@ -990,21 +1149,12 @@ def _match_out_axis(st): # broadcast write -- restore from a single lane oth_out_states.append(st) - # B1: fold promoted read-modify-write states into the input groups so the - # engine feeds them per-lane on input and scatters them on output (exactly as - # if the user had declared them in state_in_axes/state_out_axes at this axis). - if promoted_in_states: - merged_in: Dict[int, List[State]] = defaultdict(list) - for axis, states in in_groups: - merged_in[axis].extend(states) - for axis, states in promoted_in_states.items(): - merged_in[axis].extend(states) - in_groups = sorted(merged_in.items(), key=lambda kv: kv[0]) - out_groups = sorted(out_axis_groups.items(), key=lambda kv: kv[0]) # Return a live (strong-ref) plan for immediate execution; the caller - # snapshots it as a weakref-backed StateMapPlan for the warm cache (B3). - return LiveStateMapPlan(in_groups, rng_states, out_groups, oth_out_states) + # snapshots it as a weakref-backed StateMapPlan for the warm cache (B3). The + # auto read-modify-write candidates are promoted per-call in _execute_plan. + return LiveStateMapPlan(in_groups, rng_states, out_groups, oth_out_states, + frozenset(auto_in_candidate_ids)) class _PlanStaleError(Exception): @@ -1030,6 +1180,32 @@ def _execute_plan(plan: LiveStateMapPlan, f, args, kwargs, in_axes, out_axes, out_groups = plan.out_groups oth_out_states = plan.oth_out_states + # Batch size from declared inputs + args (before any per-call promotion). + batch_size = _get_batch_size(args, in_axes, dict(in_groups), axis_size, dyn_kwargs) + + # #1: re-evaluate undeclared 'auto' read-modify-write promotion against the + # LIVE state value on every call so warm calls match cold calls. A candidate + # (already present in out_groups) whose current leading size equals the batch + # is fed per lane this call instead of broadcast; otherwise it is left to + # scatter. Promotion only folds in states already sized to the batch, so + # ``batch_size`` is unaffected and need not be recomputed. + if plan.auto_in_candidate_ids: + already_in = {id(st) for _, states in in_groups for st in states} + promote: Dict[int, List[State]] = defaultdict(list) + for axis, states in out_groups: + for st in states: + if (id(st) in plan.auto_in_candidate_ids + and id(st) not in already_in + and _leaf_axis_size(st.value, axis) == batch_size): + promote[axis].append(st) + if promote: + merged: Dict[int, List[State]] = defaultdict(list) + for axis, states in in_groups: + merged[axis].extend(states) + for axis, states in promote.items(): + merged[axis].extend(states) + in_groups = sorted(merged.items(), key=lambda kv: kv[0]) + in_group_axes = [axis for axis, _ in in_groups] if len(in_group_axes) == 0: in_group_axes = 0 @@ -1037,7 +1213,6 @@ def _execute_plan(plan: LiveStateMapPlan, f, args, kwargs, in_axes, out_axes, if len(out_group_axes) == 0: out_group_axes = 0 - batch_size = _get_batch_size(args, in_axes, dict(in_groups), axis_size, dyn_kwargs) rng_keys, rng_backups = split_rng_keys(rng_states, batch_size) in_group_vals = [[st.value for st in states] for _, states in in_groups] @@ -1136,6 +1311,8 @@ def state_map_transform( static_argnums=(), static_argnames=(), name: Optional[str] = None, + out_decl_name: str = 'state_out_axes', + out_decl_extra: str = " or set unexpected_out_state_mapping to 'auto', 'warn', or 'ignore'", ): """Build a state-aware mapped version of ``f``. @@ -1164,8 +1341,12 @@ def state_map_transform( Policy for states written with a batched value but not declared in ``state_out_axes``. static_argnums, static_argnames : int/str or iterable - Positional/keyword arguments treated as compile-time constants when - building the per-signature cache key. + Positional/keyword arguments treated as compile-time constants + (``jax.jit`` parity). They key the per-signature plan cache and are + closed over -- neither traced nor mapped -- so a ``static_argnums`` + position is excluded from ``in_axes`` mapping entirely (its ``in_axes`` + entry, if any, is ignored). Negative ``static_argnums`` count from the + end; an out-of-range index raises :class:`ValueError`. name : str, optional Diagnostic name. @@ -1192,7 +1373,14 @@ def wrapped(*args, **kwargs): "to the positional arguments passed to the function, but got " f"{len(in_axes)} in_axes entries for {len(args)} positional arguments." ) - cache_key = get_arg_cache_key(static_argnums, static_argnames, args, kwargs) + # #8: positional ``static_argnums`` are compile-time constants (jit + # parity). Normalize/validate the indices, key the cache on their values + # (via get_arg_cache_key), then close over them so they are neither + # traced nor mapped -- the engine below sees a function of only the + # remaining dynamic positional arguments. + static_nums = _normalize_static_argnums(static_argnums, len(args)) + cache_key = get_arg_cache_key(static_nums, static_argnames, args, kwargs) + cf, cargs, caxes = _close_static_argnums(f, in_axes, static_nums, args) plan = cache.get(cache_key, None) # B3: a cached plan holds only weakrefs; materialize() returns None if # any of its states was garbage-collected (e.g. the caller rebuilt its @@ -1201,16 +1389,17 @@ def wrapped(*args, **kwargs): live = plan.materialize() if plan is not None else None if live is None: live = _build_plan( - f, args, kwargs, + cf, cargs, kwargs, in_predicates, out_predicates, - in_axes, axis_size, axis_name, + caxes, axis_size, axis_name, unexpected_out_state_mapping, name, static_argnames=static_argnames, + out_decl_name=out_decl_name, out_decl_extra=out_decl_extra, ) cache[cache_key] = StateMapPlan.from_live(live) try: return _execute_plan( - live, f, args, kwargs, in_axes, out_axes, + live, cf, cargs, kwargs, caxes, out_axes, axis_size, axis_name, mapping_fn, mapping_kwargs, static_argnames=static_argnames, ) @@ -1220,15 +1409,16 @@ def wrapped(*args, **kwargs): # plan, re-probe, and retry exactly once. A second divergence # (write set changing within a single call) is surfaced as-is. live = _build_plan( - f, args, kwargs, + cf, cargs, kwargs, in_predicates, out_predicates, - in_axes, axis_size, axis_name, + caxes, axis_size, axis_name, unexpected_out_state_mapping, name, static_argnames=static_argnames, + out_decl_name=out_decl_name, out_decl_extra=out_decl_extra, ) cache[cache_key] = StateMapPlan.from_live(live) return _execute_plan( - live, f, args, kwargs, in_axes, out_axes, + live, cf, cargs, kwargs, caxes, out_axes, axis_size, axis_name, mapping_fn, mapping_kwargs, static_argnames=static_argnames, ) diff --git a/brainstate/transform/_mapping_core_test.py b/brainstate/transform/_mapping_core_test.py index ce9847a..50d1cc6 100644 --- a/brainstate/transform/_mapping_core_test.py +++ b/brainstate/transform/_mapping_core_test.py @@ -379,6 +379,97 @@ def f(x): self.assertEqual(b.value.shape, (B, N)) self.assertTrue(jnp.allclose(b.value, jnp.full((B, N), 4.0))) + def test_rmw_dim_ne_batch_stable_across_warm_calls(self): + # #1: undeclared RMW whose leading dim (3) != batch (5). The first call + # scatters (3,) -> (5, 3); a warm (cached) call must NOT grow it again. + import warnings + w = brainstate.ShortTermState(jnp.zeros(3)) + + def f(x): + w.value = w.value + x + return w.value + + mapped = brainstate.transform.vmap2(f) # batch 5 != dim 3 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + mapped(jnp.arange(5.)) + s1 = w.value.shape + self.assertEqual(s1, (5, 3)) + mapped(jnp.arange(5.)) # warm call must not grow + self.assertEqual(w.value.shape, s1) # was (5,3) -> (5,5,3) before the fix + + def test_rmw_dim_ne_batch_warm_equals_cold(self): + # #1: a cached (warm) wrapper and a fresh (cold) wrapper must agree on the + # resulting shape across repeated calls (the "make warm == cold" fix). + import warnings + + def run(warm): + w = brainstate.ShortTermState(jnp.zeros(3)) + + def f(x): + w.value = w.value + x + return w.value + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if warm: + m = brainstate.transform.vmap2(f) + for _ in range(4): + m(jnp.arange(5.)) + else: + for _ in range(4): + brainstate.transform.vmap2(f)(jnp.arange(5.)) + return w.value + + warm_v, cold_v = run(warm=True), run(warm=False) + self.assertEqual(warm_v.shape, cold_v.shape) + self.assertTrue(jnp.allclose(warm_v, cold_v)) + + def test_auto_rmw_dim_mismatch_warns(self): + # #3: dim (3) != batch (5) on a genuinely-read (RMW) undeclared state is + # the surprising case -> a one-time UserWarning. + w = brainstate.ShortTermState(jnp.zeros(3)) + + def f(x): + w.value = w.value + x + return w.value + + mapped = brainstate.transform.vmap2(f) + with self.assertWarns(UserWarning): + mapped(jnp.arange(5.)) + + def test_auto_rmw_dim_match_no_warning(self): + # #3: pre-shaped RMW (dim == batch) is the recommended pattern -> silent. + import warnings + w = brainstate.ShortTermState(jnp.zeros(4)) + + def f(x): + w.value = w.value + x + return w.value + + mapped = brainstate.transform.vmap2(f) # batch 4 == dim 4 + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + mapped(jnp.ones(4)) + self.assertEqual([r for r in rec if issubclass(r.category, UserWarning)], []) + + def test_pure_output_dim_ne_batch_no_warning(self): + # #3: a pure-output (never-read) undeclared scatter is legitimate even + # when dim != batch -> silent (no false positive on correct code). + import warnings + w = brainstate.ShortTermState(jnp.zeros(3)) + + def f(x): + w.value = x * 2.0 # never reads w + return x + + mapped = brainstate.transform.vmap2(f) + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + mapped(jnp.arange(5.)) + self.assertEqual([r for r in rec if issubclass(r.category, UserWarning)], []) + self.assertEqual(w.value.shape, (5,)) + class TestStalePlanInvalidation(unittest.TestCase): """B3: a cached plan is invalidated when its States are recreated (GC'd). diff --git a/brainstate/transform/_shard_map.py b/brainstate/transform/_shard_map.py index 9d53d50..6e675d4 100644 --- a/brainstate/transform/_shard_map.py +++ b/brainstate/transform/_shard_map.py @@ -45,6 +45,53 @@ def _resolve_state_spec(state: State, table) -> PartitionSpec: return table # a single PartitionSpec applied to all states +def _augment_shard_state_error(exc, all_states, in_state_specs, local_arg_specs): + """Return a clearer *exception* for the common undeclared-state mismatch. + + Undeclared states default to **replication** -- placed at their full, global + shape on every shard -- while positional data is sharded into smaller + per-shard slices. Combining the two inside ``fun`` (e.g. ``buffer.value + + data``) raises an opaque broadcasting error that gives no hint about + ``state_in_specs`` (audit #7). When the failure looks like a shape mismatch + and a replicated touched state coexists with sharded data, return a new + exception (same type as ``exc`` when it can be rebuilt from a single message, + else :class:`RuntimeError`) carrying an actionable message; otherwise return + ``None`` so the original error propagates unchanged. + """ + msg = str(exc) + lowered = msg.lower() + looks_like_shape = ( + 'incompatible shapes' in lowered + or 'broadcast' in lowered + or ('shape' in lowered and 'mismatch' in lowered) + ) + if not looks_like_shape: + return None + replicated = [st for st, sp in zip(all_states, in_state_specs) if sp == PartitionSpec()] + sharded_data = any(sp != PartitionSpec() for sp in local_arg_specs) + if not (replicated and sharded_data): + return None + kinds = ', '.join(sorted({type(st).__name__ for st in replicated})) + hint = ( + f"{msg}\n\n" + "This shape mismatch usually means an undeclared State is replicated " + "(given its full, global shape on every shard) while the positional data " + "is sharded (a smaller per-shard slice), so combining them inside the " + "function fails. Undeclared states default to replication. If the state " + "should vary per shard, give it a matching partition via state_in_specs " + "(and state_out_specs for writes), e.g. " + "state_in_specs={the_state: P('x')}. " + f"Replicated touched state kind(s): {kinds}." + ) + # Preserve the original exception type when it can be rebuilt from a single + # string; otherwise fall back to RuntimeError so we never crash while trying + # to produce a friendlier message. + try: + return type(exc)(hint) + except Exception: + return RuntimeError(hint) + + @set_module_as("brainstate.transform") def shard_map( fun: Callable, @@ -104,6 +151,17 @@ def shard_map( ``State.restore_value``, runs ``fun``, and restores every state afterward (writes to their new values, reads to their originals). + States are **replicated by default** -- a state not covered by + ``state_in_specs`` / ``state_out_specs`` is placed at its full, global shape on + every shard. Combining such a replicated state with *sharded* positional data + (a smaller per-shard slice) inside ``fun`` -- e.g. ``buffer.value + data`` -- + raises a shape-mismatch error, because the operands have different per-shard + sizes. To keep a per-shard buffer, give the state a matching partition through + ``state_in_specs`` (and ``state_out_specs`` if it is written), as in the + per-shard buffer example above. Replicated states are appropriate for values + that are the same on every shard (scalars, shared parameters) or that are + reduced with a collective such as :func:`jax.lax.psum` before being written. + Examples -------- .. code-block:: python @@ -251,10 +309,15 @@ def pure(state_vals, mapped_args): sharded = jax_shard_map(pure, **sm_kwargs) try: out, write_vals = sharded(in_state_vals, sharded_args) - except Exception: + except Exception as e: # a failure mid-trace must not leave shard tracers in the states for st, ov in zip(all_states, orig_vals): st.restore_value(ov) + # #7: turn the opaque broadcasting error from an undeclared + # (replicated) state meeting sharded data into an actionable one. + new_exc = _augment_shard_state_error(e, all_states, in_state_specs, local_arg_specs) + if new_exc is not None: + raise new_exc from e raise # 6. Restore ALL states: writes -> new values, reads -> originals diff --git a/brainstate/transform/_shard_map_test.py b/brainstate/transform/_shard_map_test.py index 1292b61..5155536 100644 --- a/brainstate/transform/_shard_map_test.py +++ b/brainstate/transform/_shard_map_test.py @@ -55,6 +55,65 @@ def fun(data): self.assertTrue(jnp.allclose(out_state.value, data * 3.0)) self.assertTrue(jnp.allclose(w.value, 3.0)) # replicated read state unchanged + def test_undeclared_pershard_write_error_points_to_state_in_specs(self): + # #7: a replicated (undeclared) state read+written together with sharded + # data fails with a cryptic broadcast shape error. The wrapper should + # augment it to point at state_in_specs / state_out_specs. + if self.n < 2: + self.skipTest("Requires at least 2 devices") + buffer = brainstate.State(jnp.zeros(self.n * 2)) # replicated by default + + def accumulate(data): + buffer.value = buffer.value + data # full (n*2,) + per-shard slice + return data + + f = brainstate.transform.shard_map( + accumulate, self.mesh, in_specs=(P('x'),), out_specs=P('x'), + ) + with self.assertRaises(Exception) as cm: + f(jnp.arange(self.n * 2, dtype=jnp.float32)) + self.assertIn('state_in_specs', str(cm.exception)) + + # The error-augmenter is a pure function; unit-test it directly so its + # behavior is verified on any number of devices (the integration test above + # needs >= 2 devices and skips on single-device CI). + def test_augment_returns_actionable_exception_for_shape_mismatch(self): + from brainstate.transform._shard_map import _augment_shard_state_error + st = brainstate.State(jnp.zeros(4)) + new_exc = _augment_shard_state_error( + ValueError("Incompatible shapes for broadcasting: (4,) and (2,)"), + [st], [P()], [P('x')], + ) + self.assertIsInstance(new_exc, ValueError) # original type preserved + self.assertIn('state_in_specs', str(new_exc)) + self.assertIn('State', str(new_exc)) # replicated kind named + + def test_augment_ignores_non_shape_errors(self): + from brainstate.transform._shard_map import _augment_shard_state_error + st = brainstate.State(jnp.zeros(4)) + self.assertIsNone(_augment_shard_state_error( + ValueError("some unrelated runtime failure"), [st], [P()], [P('x')])) + + def test_augment_ignores_when_no_sharded_data(self): + from brainstate.transform._shard_map import _augment_shard_state_error + st = brainstate.State(jnp.zeros(4)) + # replicated state but every arg is replicated too -> not the diagnosed case + self.assertIsNone(_augment_shard_state_error( + ValueError("incompatible shapes"), [st], [P()], [P()])) + + def test_augment_falls_back_to_runtimeerror_when_type_not_rebuildable(self): + from brainstate.transform._shard_map import _augment_shard_state_error + + class WeirdError(Exception): + def __init__(self, a, b): # cannot be rebuilt from a single string + super().__init__(a) + + st = brainstate.State(jnp.zeros(4)) + new_exc = _augment_shard_state_error( + WeirdError("incompatible shapes", "extra"), [st], [P()], [P('x')]) + self.assertIsInstance(new_exc, RuntimeError) + self.assertIn('state_in_specs', str(new_exc)) + def test_no_state_function(self): def fun(data): return data + 1.0