Skip to content

fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine#216

Merged
chaoming0625 merged 2 commits into
mainfrom
worktree-mapping-transform-fixes
Jun 13, 2026
Merged

fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine#216
chaoming0625 merged 2 commits into
mainfrom
worktree-mapping-transform-fixes

Conversation

@chaoming0625

Copy link
Copy Markdown
Collaborator

Summary

An audit of the brainstate.transform mapping APIs (vmap, vmap2, pmap2, map, shard_map, vmap2_new_states / pmap2_new_states) surfaced 8 verified bugs. Each is fixed here with a reproducing test (RED→GREEN); the full brainstate/transform/ suite passes (1156 passed, 1 skipped).

# Severity API Bug
1 High 'auto' policy undeclared read-modify-write state with leading dim ≠ batch grew a new leading axis on every warm call
2 High pmap2_new_states failed when init used no RandomState (pmap needs a mapped arg)
3 Medium 'auto' policy silently flipped RMW-vs-scatter on a coincidental dim match
4 Medium vmap family axis_size never validated against the inferred batch size
5 Low legacy vmap undeclared-write error referenced engine internals
6 Low map 0-d (scalar) input → cryptic IndexError
7 Low-Med shard_map undeclared per-shard write → cryptic shape-mismatch error
8 Low StatefulMapping(static_argnums=...) static positional arg still mapped → 'bool' has no attribute 'ndim'

Details

  • Update documtation #1 — per-lane promotion of an undeclared RMW state is now decided at execution time from the live value on every call, so warm (cached) calls match cold calls and the value no longer grows an axis each call.
  • add brainstate.nn.Embedding; support reset_state() #2 — when there are no random states, a dummy iota of length axis_size is fed (and ignored) so jax.pmap still has a mapped argument.
  • rewrite _make_jaxpr() function to be compatible with jax==0.4.29 #3 — a new _ReadTrackingTrace distinguishes a genuine state.value read from the internal read inside write_its_value; a genuinely-read undeclared state whose leading dim ≠ batch now emits a one-time UserWarning instead of silently scattering.
  • remove math module #4axis_size that conflicts with the inferred batch size now raises a clear ValueError instead of a late, opaque XLA buffer error.
  • Linear interpolation of the delay #5 — the legacy vmap undeclared-write error now speaks the caller's vocabulary (out_states) via out_decl_name / out_decl_extra hooks, not state_out_axes / unexpected_out_state_mapping.
  • Simplify the linear interpolation to accelerate brainstate.Delay #6map over a 0-d input raises a clear ValueError naming the missing leading axis.
  • Fix jit error #7 — a replicated (undeclared) state combined with sharded data now produces an actionable error pointing at state_in_specs / state_out_specs; the shard_map Notes document the replication default.
  • Updates #8 — positional static_argnums are closed over (jax.jit parity): neither traced nor mapped, with negative-index normalization and an out-of-range ValueError.

Testing

  • New reproducing tests for every issue across _mapping_core_test.py, _mapping1_test.py, _mapping2_test.py, _shard_map_test.py.
  • python -m pytest brainstate/transform/1156 passed, 1 skipped, 0 failed.
  • Doctests on the touched modules pass.

Audit of brainstate.transform mapping APIs surfaced 8 verified issues; each
is fixed with a reproducing test (RED->GREEN).

#1 (High) 'auto' undeclared read-modify-write state with leading dim != batch
  grew a new leading axis on every warm (cached) call. The per-lane promotion
  decision is now deferred to execution time and re-made from the live value on
  every call, so warm calls match cold calls and the shape stays stable.

#2 (High) pmap2_new_states failed when the init used no RandomState, because
  jax.pmap requires at least one mapped argument. A dummy iota of length
  axis_size is now fed (and ignored) when there are no random states.

#3 (Medium) 'auto' silently flipped read-modify-write vs scatter on a
  coincidental leading-dim match. A genuinely-read undeclared state whose dim
  does not match the batch now emits a one-time UserWarning; a _ReadTrackingTrace
  distinguishes genuine reads from the internal read inside write_its_value.

#4 (Medium) axis_size was never validated against the inferred batch size; a
  mismatch now raises a clear ValueError instead of an opaque late XLA error.

#5 (Low) the legacy vmap undeclared-write error referenced engine internals
  (state_out_axes / unexpected_out_state_mapping). It now speaks the legacy
  vocabulary (out_states) via out_decl_name / out_decl_extra hooks.

#6 (Low) map over a 0-d (scalar) input raised a cryptic IndexError; it now
  raises a clear ValueError naming the missing leading axis.

#7 (Low-Med) shard_map undeclared (replicated) per-shard write meeting sharded
  data raised an opaque broadcast error; it is now augmented to point at
  state_in_specs / state_out_specs, and the Notes document the replication
  default.

#8 (Low) StatefulMapping(static_argnums=...) did not exclude the positional arg
  from mapping ('bool' has no attribute 'ndim'). Static positional args are now
  closed over (jax.jit parity): neither traced nor mapped, with negative-index
  normalization and an out-of-range ValueError.

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @chaoming0625, you have reached your weekly rate limit of 500000 diff characters.

Please try again later or upgrade to continue using Sourcery

@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 98.13084% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
brainstate/transform/_shard_map.py 90.00% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

…e mesh

CI runs the full suite on a single CPU device, so the shard_map error-augment
path (#7) and the tuple-in_axes static-arg branch (#8) were exercised only by
tests that skip on one device, dropping patch coverage below the 90% gate.

- Refactor _augment_shard_state_error to return the rebuilt exception (folding
  the type(e)(hint) / RuntimeError fallback into the pure helper) so the whole
  augmentation path is unit-testable off-device; simplify shard_map's except.
- Add device-independent unit tests for the helper (shape-mismatch success,
  non-shape passthrough, no-sharded-data passthrough, RuntimeError fallback).
- Add a tuple-in_axes StatefulMapping(static_argnums=...) test.

Patch coverage on a single device rises from ~89% to ~99%; the only remaining
uncovered line is the final `raise ... from e`, which needs a real 2-device
shape mismatch and is covered by the existing integration test on multi-device
machines. Behavior is unchanged.
@chaoming0625 chaoming0625 merged commit d696261 into main Jun 13, 2026
7 checks passed
@chaoming0625 chaoming0625 deleted the worktree-mapping-transform-fixes branch June 13, 2026 07:52
chaoming0625 added a commit that referenced this pull request Jun 13, 2026
…217)

Resolves every issue catalogued in the cross-module audit (dev/issues.md):
Critical C1, High H1-H21, Medium M1-M46, Low L1-L29, the needs-human-judgment
items NJ1-NJ4, and the 110-entry appendix of unverified findings. Each fix is
paired with a behavioral regression test.

Highlights:
- Runtime validation raises TypeError/ValueError (not assert, which is stripped
  under python -O) across nn, random, transform, util, graph, interop and core.
- Fixes target stable public JAX APIs; verified green on the full CI JAX matrix
  (0.7.0, 0.8.0, 0.9.0, latest).
- Genuinely ambiguous behavioral contracts (the NJ items) are resolved by
  documenting the existing behavior rather than silently changing it.
- Reconciled with the independently-merged mapping-engine rewrite (#216): where
  the audit overlapped, #216's reviewed rewrite was kept; only additive,
  non-overlapping items were re-applied.

Test suite: 5296 passed, 23 skipped. mypy clean. Patch coverage 100% (lines).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant