fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216
Merged
Audit of brainstate.transform mapping APIs surfaced 8 verified issues; each is fixed with a reproducing test (RED->GREEN).

#1 (High) 'auto' undeclared read-modify-write state with leading dim != batch grew a new leading axis on every warm (cached) call. The per-lane promotion decision is now deferred to execution time and re-made from the live value on every call, so warm calls match cold calls and the shape stays stable.

#2 (High) pmap2_new_states failed when the init used no RandomState, because jax.pmap requires at least one mapped argument. A dummy iota of length axis_size is now fed (and ignored) when there are no random states.

#3 (Medium) 'auto' silently flipped read-modify-write vs scatter on a coincidental leading-dim match. A genuinely-read undeclared state whose dim does not match the batch now emits a one-time UserWarning; a _ReadTrackingTrace distinguishes genuine reads from the internal read inside write_its_value.

#4 (Medium) axis_size was never validated against the inferred batch size; a mismatch now raises a clear ValueError instead of an opaque late XLA error.

#5 (Low) the legacy vmap undeclared-write error referenced engine internals (state_out_axes / unexpected_out_state_mapping). It now speaks the legacy vocabulary (out_states) via out_decl_name / out_decl_extra hooks.

#6 (Low) map over a 0-d (scalar) input raised a cryptic IndexError; it now raises a clear ValueError naming the missing leading axis.

#7 (Low-Med) shard_map undeclared (replicated) per-shard write meeting sharded data raised an opaque broadcast error; it is now augmented to point at state_in_specs / state_out_specs, and the Notes document the replication default.

#8 (Low) StatefulMapping(static_argnums=...) did not exclude the positional arg from mapping ('bool' has no attribute 'ndim'). Static positional args are now closed over (jax.jit parity): neither traced nor mapped, with negative-index normalization and an out-of-range ValueError.
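The validation described in #4 and #6 can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this PR: `infer_batch_size` is a hypothetical helper that works on shape tuples instead of real arrays, and the error messages are invented for the example.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def infer_batch_size(shapes: Sequence[Shape], axis_size: Optional[int] = None) -> int:
    """Infer the mapped leading-axis size from input shapes (hypothetical helper)."""
    sizes = set()
    for i, shape in enumerate(shapes):
        if len(shape) == 0:
            # A 0-d input has no leading axis to map over (the situation in #6).
            raise ValueError(
                f"input {i} is 0-d (scalar) and has no leading axis to map over"
            )
        sizes.add(shape[0])

    if len(sizes) > 1:
        raise ValueError(f"inconsistent leading-axis sizes: {sorted(sizes)}")

    if not sizes:
        # Nothing to infer from, so the caller must say how many lanes to run.
        if axis_size is None:
            raise ValueError("axis_size is required when there are no mapped inputs")
        return axis_size

    (inferred,) = sizes
    if axis_size is not None and axis_size != inferred:
        # Fail early with a readable message (the situation in #4).
        raise ValueError(
            f"axis_size={axis_size} does not match the inferred batch size {inferred}"
        )
    return inferred
```

Checking up front keeps the failure close to the call site, rather than letting a mismatch surface later as a backend error.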
…e mesh

CI runs the full suite on a single CPU device, so the shard_map error-augment path (#7) and the tuple-in_axes static-arg branch (#8) were exercised only by tests that skip on one device, dropping patch coverage below the 90% gate.

- Refactor _augment_shard_state_error to return the rebuilt exception (folding the type(e)(hint) / RuntimeError fallback into the pure helper) so the whole augmentation path is unit-testable off-device; simplify shard_map's except.
- Add device-independent unit tests for the helper (shape-mismatch success, non-shape passthrough, no-sharded-data passthrough, RuntimeError fallback).
- Add a tuple-in_axes StatefulMapping(static_argnums=...) test.

Patch coverage on a single device rises from ~89% to ~99%; the only remaining uncovered line is the final `raise ... from e`, which needs a real 2-device shape mismatch and is covered by the existing integration test on multi-device machines. Behavior is unchanged.
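The "return the rebuilt exception" refactor can be illustrated with a small pure function. This is a sketch under assumptions, not the actual `_augment_shard_state_error`: the name `augment_shard_error`, its parameters, and the substring test used to detect a shape mismatch are all hypothetical.

```python
def augment_shard_error(e: Exception, has_sharded_data: bool, hint: str) -> Exception:
    """Return an exception to raise: either `e` itself or a rebuilt one with a hint.

    Hypothetical helper; the "shape" substring check is an assumed heuristic.
    """
    # Pass through unrelated errors, and errors raised when no sharded data is involved.
    if not has_sharded_data or "shape" not in str(e).lower():
        return e

    message = f"{e}\n\n{hint}"
    try:
        # Prefer the original type so existing `except` clauses still match.
        return type(e)(message)
    except Exception:
        # Some exception types cannot be built from a single message string.
        return RuntimeError(message)
```

Because the helper only builds and returns an exception, every branch can be unit-tested without devices; the caller keeps just the `try` / `except` and the final `raise ... from e`.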
chaoming0625 added a commit that referenced this pull request on Jun 13, 2026
…217)

Resolves every issue catalogued in the cross-module audit (dev/issues.md): Critical C1, High H1-H21, Medium M1-M46, Low L1-L29, the needs-human-judgment items NJ1-NJ4, and the 110-entry appendix of unverified findings. Each fix is paired with a behavioral regression test.

Highlights:

- Runtime validation raises TypeError/ValueError (not assert, which is stripped under python -O) across nn, random, transform, util, graph, interop and core.
- Fixes target stable public JAX APIs; verified green on the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, latest).
- Genuinely ambiguous behavioral contracts (the NJ items) are resolved by documenting the existing behavior rather than silently changing it.
- Reconciled with the independently-merged mapping-engine rewrite (#216): where the audit overlapped, #216's reviewed rewrite was kept; only additive, non-overlapping items were re-applied.

Test suite: 5296 passed, 23 skipped. mypy clean. Patch coverage 100% (lines).
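The first highlight, explicit exceptions instead of `assert`, follows a pattern like the one below. `validate_axis_size` is a hypothetical function written for illustration; it is not taken from the commit.

```python
def validate_axis_size(axis_size):
    """Hypothetical validator that still runs under `python -O`."""
    # An `assert isinstance(axis_size, int)` here would be removed by `python -O`,
    # so the check is written as an explicit raise instead.
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(axis_size, bool) or not isinstance(axis_size, int):
        raise TypeError(
            f"axis_size must be an int, got {type(axis_size).__name__}"
        )
    if axis_size <= 0:
        raise ValueError(f"axis_size must be positive, got {axis_size}")
    return axis_size
```

Wrong types raise TypeError and wrong values raise ValueError, which matches the split named in the highlight.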
Merged
Summary
An audit of the brainstate.transform mapping APIs (vmap, vmap2, pmap2, map, shard_map, vmap2_new_states / pmap2_new_states) surfaced 8 verified bugs. Each is fixed here with a reproducing test (RED→GREEN); the full brainstate/transform/ suite passes (1156 passed, 1 skipped).

- #1 (High) 'auto' policy: an undeclared read-modify-write state whose leading dim ≠ batch grew a new leading axis on every warm (cached) call.
- #2 (High) pmap2_new_states: failed when the init used no RandomState (pmap needs a mapped arg).
- #3 (Medium) 'auto' policy: silently flipped between read-modify-write and scatter on a coincidental leading-dim match.
- #4 (Medium) axis_size: never validated against the inferred batch size.
- #5 (Low) legacy vmap: the undeclared-write error referenced engine internals.
- #6 (Low) map: a 0-d (scalar) input raised a cryptic IndexError.
- #7 (Low-Med) shard_map: an undeclared (replicated) per-shard write meeting sharded data raised an opaque broadcast error.
- #8 (Low) StatefulMapping(static_argnums=...): the static positional arg was not excluded from mapping ('bool' has no attribute 'ndim').

Details

- #1: the per-lane promotion decision is now deferred to execution time and re-made from the live value on every call, so warm calls match cold calls and the shape stays stable.
- #2: when there are no random states, a dummy iota of length axis_size is fed (and ignored) so jax.pmap still has a mapped argument.
- #3: a new _ReadTrackingTrace distinguishes a genuine state.value read from the internal read inside write_its_value; a genuinely-read undeclared state whose leading dim ≠ batch now emits a one-time UserWarning instead of silently scattering.
- #4: an axis_size that conflicts with the inferred batch size now raises a clear ValueError instead of a late, opaque XLA buffer error.
- #5: the legacy vmap undeclared-write error now speaks the caller's vocabulary (out_states) via out_decl_name / out_decl_extra hooks, not state_out_axes / unexpected_out_state_mapping.
- #6: map over a 0-d input raises a clear ValueError naming the missing leading axis.
- #7: the opaque broadcast error is now augmented to point at state_in_specs / state_out_specs; the shard_map Notes document the replication default.
- #8: positional args named in static_argnums are closed over (jax.jit parity): neither traced nor mapped, with negative-index normalization and an out-of-range ValueError.

Testing

- _mapping_core_test.py, _mapping1_test.py, _mapping2_test.py, _shard_map_test.py.
- python -m pytest brainstate/transform/ → 1156 passed, 1 skipped, 0 failed.
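The static_argnums handling in #8 (negative-index normalization, an out-of-range ValueError, and closing over the static values) can be sketched as follows. Both function names are hypothetical and the code is a plain-Python illustration of the idea, not the StatefulMapping implementation.

```python
from typing import Any, Callable, Iterable, Tuple


def normalize_static_argnums(static_argnums: Iterable[int], n_args: int) -> Tuple[int, ...]:
    """Map negative indices to positive ones and reject out-of-range entries."""
    normalized = set()
    for idx in static_argnums:
        if not -n_args <= idx < n_args:
            raise ValueError(
                f"static_argnums entry {idx} is out of range for "
                f"{n_args} positional arguments"
            )
        normalized.add(idx % n_args)
    return tuple(sorted(normalized))


def close_over_static(
    fn: Callable[..., Any], args: Tuple[Any, ...], static_argnums: Iterable[int]
) -> Tuple[Callable[..., Any], Tuple[Any, ...]]:
    """Split args into static values (captured) and dynamic values (returned)."""
    static = normalize_static_argnums(static_argnums, len(args))
    static_values = {i: args[i] for i in static}
    dynamic_args = tuple(a for i, a in enumerate(args) if i not in static_values)

    def wrapped(*dyn: Any) -> Any:
        # Re-interleave captured static values with the dynamic arguments.
        it = iter(dyn)
        full = [static_values[i] if i in static_values else next(it)
                for i in range(len(args))]
        return fn(*full)

    return wrapped, dynamic_args
```

Only `dynamic_args` would then be handed to the mapping transform, so a static Python value such as a `bool` is never inspected for an `ndim`.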