fix(nn): resolve bugs and edge cases found across the nn module audit #215
Merged
Systematic audit of brainstate.nn. Each fix is covered by a regression test (several previously-skipped "known bug" tests are now un-skipped and pass).

Dropout/elementwise/activations:
- AlphaDropout/FeatureAlphaDropout: correct self-normalizing affine constants.
- Dropout2d/3d unbatched minimal-dim detection (per-element mask independence).
- Softmin/Softmax/LogSoftmax default dim -> last axis; rrelu unit/int handling; soft_shrink zero-branch unit; channel-last docstrings.

Linear/init/utils:
- ScaledWSLinear mask/weight/bias shapes; AllToAll out>in padding.
- TruncatedNormal default bounds; clip_grad_norm unitless-gradient note.

Metrics/exp_euler:
- Precision/Recall 'weighted' average + average validation; Welford int counter.
- exp_euler diagonal-Jacobian docstring clarification.

Transforms/param/hidata:
- Softplus/NegSoftplus/Negative/Ordered: stable saturation-free forward and unit-safe stable inverse; Sigmoid/Affine log_abs_det_jacobian unit handling and per-batch shape; Affine zero-scale check on mantissa; Exp/Log/Positive saturation docs; HiData.clone/add/pop/replace preserve name.

Module/common/collective_ops:
- assign_state_values: pytree/Quantity values via tree.map; accept dotted-string and tuple keys.
- _filter_states dict branch iterates items().
- vmap_call_all_fns rebuilt on vmap_new_states (fixes BatchTracer leak).
- Map.update no longer forwards spmd_axis_name to pmap2.
- Sequential empty-slice returns empty Sequential.
- in/out_size setters accept numpy scalars and 0-d arrays uniformly.

Delay/dynamics/event_fixedprob:
- Delay.max_time grows monotonically across registrations; take_aware_unit retrieval no longer crashes / double-applies unit; update_every made functional via a monotonic per-call write pointer and correct slot-spacing on time retrieval.
- FixedNumConn afferent_ratio mask respects seed; broken efferent_target='pre' path guarded with a clear NotImplementedError.
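To make the "stable saturation-free forward and ... stable inverse" item for Softplus concrete, here is a minimal NumPy sketch of the standard numerically stable formulation. It is not the code from this PR: the function names `stable_softplus` and `stable_softplus_inverse` are hypothetical, and unit (Quantity) handling is left out entirely.

```python
import numpy as np


def stable_softplus(x):
    """Compute log(1 + exp(x)) without overflow for large |x|.

    Uses the identity log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)),
    so the argument of exp is never positive.
    """
    x = np.asarray(x, dtype=np.float64)
    return np.maximum(x, 0.0) + np.log1p(np.exp(-np.abs(x)))


def stable_softplus_inverse(y):
    """Compute log(exp(y) - 1) for y > 0 without overflow.

    Uses the identity log(exp(y) - 1) = y + log(1 - exp(-y)),
    with expm1 keeping precision when y is close to zero.
    """
    y = np.asarray(y, dtype=np.float64)
    return y + np.log(-np.expm1(-y))


if __name__ == "__main__":
    xs = np.array([-20.0, -1.0, 0.0, 1.0, 50.0])
    # Round trip: the inverse should recover the original inputs.
    print(stable_softplus_inverse(stable_softplus(xs)))
    # A naive log(1 + exp(1000)) would overflow; this form does not.
    print(stable_softplus(1000.0))
```

The naive expression `np.log(1 + np.exp(x))` overflows for large positive `x`, which is why the forward pass is rewritten around `-|x|`.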
Summary

A systematic audit of brainstate.nn surfaced a range of correctness bugs and unhandled edge cases. This PR fixes all confirmed findings. Every fix is covered by a regression test, and several previously-skipped "known bug" tests are now un-skipped and passing.

Full brainstate/nn/ suite: 1856 passed, 0 failed, 5 skipped.

Fixes by area
Dropout / elementwise / activations
- AlphaDropout/FeatureAlphaDropout: correct self-normalizing affine constants.
- Dropout2d/Dropout3d unbatched minimal-dim detection (per-element mask independence).
- Softmin/Softmax/LogSoftmax default dim → last axis; rrelu unit/int handling; soft_shrink zero-branch unit; channel-last docstrings.

Linear / init / utils
- ScaledWSLinear mask/weight/bias shapes; AllToAll out>in padding.
- TruncatedNormal default bounds; clip_grad_norm unitless-gradient note.

Metrics / exp_euler
- Precision/Recall 'weighted' average + average validation; Welford int counter.
- exp_euler diagonal-Jacobian docstring clarification.

Transforms / param / hidata
- Sigmoid/Affine log_abs_det_jacobian unit handling and per-batch shape; Affine zero-scale check on mantissa; Exp/Log/Positive saturation docs.
- HiData.clone/add/pop/replace preserve name.

Module / common / collective_ops
- assign_state_values: pytree/Quantity values via tree.map; accept dotted-string and tuple keys.
- _filter_states dict branch iterates .items().
- vmap_call_all_fns rebuilt on vmap_new_states (fixes a BatchTracer leak).
- Map.update no longer forwards spmd_axis_name to pmap2.
- Sequential empty-slice returns an empty Sequential.
- in_size/out_size setters accept numpy scalars and 0-d arrays uniformly.

Delay / dynamics / event_fixedprob
- Delay.max_time grows monotonically across registrations.
- take_aware_unit retrieval no longer crashes / double-applies the unit.
- update_every made functional via a monotonic per-call write pointer, with correct slot-spacing on time-based retrieval.
- FixedNumConn afferent-ratio mask respects seed; broken efferent_target='pre' path guarded with a clear NotImplementedError.

Testing
python -m pytest brainstate/nn/ -q → 1856 passed, 0 failed, 5 skipped.

🤖 Generated with Claude Code
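For readers unfamiliar with the "Welford int counter" item under Metrics / exp_euler, the following is a minimal pure-Python sketch of Welford's online mean/variance algorithm with an integer sample counter. It only illustrates the general technique: the class name `WelfordAccumulator` and its attributes are hypothetical and are not taken from brainstate. One plausible motivation for an integer counter is that a float32 counter stops increasing once it reaches 2**24, although the PR text does not state the reason.

```python
class WelfordAccumulator:
    """Online mean and population variance via Welford's algorithm."""

    def __init__(self):
        self.count = 0    # integer sample counter (never a float)
        self.mean = 0.0   # running mean
        self.m2 = 0.0     # running sum of squared deviations from the mean

    def update(self, x):
        """Fold one sample into the running statistics."""
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        # Uses the *updated* mean, which is what keeps the update stable.
        self.m2 += delta * (x - self.mean)

    @property
    def variance(self):
        """Population variance; NaN before any sample has been seen."""
        if self.count == 0:
            return float("nan")
        return self.m2 / self.count


if __name__ == "__main__":
    acc = WelfordAccumulator()
    for value in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
        acc.update(value)
    print(acc.count, acc.mean, acc.variance)
```

Keeping `count` as a Python `int` means it can grow without loss of precision, regardless of the floating-point type used for `mean` and `m2`.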