Skip to content

fix(nn): resolve bugs and edge cases found across the nn module audit#215

Merged
chaoming0625 merged 1 commit into
mainfrom
worktree-nn-module-audit
Jun 11, 2026
Merged

fix(nn): resolve bugs and edge cases found across the nn module audit#215
chaoming0625 merged 1 commit into
mainfrom
worktree-nn-module-audit

Conversation

@chaoming0625

Copy link
Copy Markdown
Collaborator

Summary

A systematic audit of brainstate.nn surfaced a range of correctness bugs and
unhandled edge cases. This PR fixes all confirmed findings. Every fix is covered
by a regression test, and several previously-skipped "known bug" tests are now
un-skipped and passing.

Full brainstate/nn/ suite: 1856 passed, 0 failed, 5 skipped.

Fixes by area

Dropout / elementwise / activations

  • AlphaDropout / FeatureAlphaDropout: correct self-normalizing affine constants.
  • Dropout2d/Dropout3d unbatched minimal-dim detection (per-element mask independence).
  • Softmin/Softmax/LogSoftmax default dim → last axis; rrelu unit/int handling; soft_shrink zero-branch unit; channel-last docstrings.

Linear / init / utils

  • ScaledWSLinear mask/weight/bias shapes; AllToAll out>in padding.
  • TruncatedNormal default bounds; clip_grad_norm unitless-gradient note.

Metrics / exp_euler

  • Precision/Recall 'weighted' average + average validation; Welford int counter.
  • exp_euler diagonal-Jacobian docstring clarification.

Transforms / param / hidata

  • Softplus/NegSoftplus/Negative/Ordered: saturation-free forward and unit-safe stable inverse.
  • Sigmoid/Affine log_abs_det_jacobian unit handling and per-batch shape; Affine zero-scale check on mantissa; Exp/Log/Positive saturation docs.
  • HiData.clone/add/pop/replace preserve name.

Module / common / collective_ops

  • assign_state_values: pytree/Quantity values via tree.map; accept dotted-string and tuple keys.
  • _filter_states dict branch iterates .items().
  • vmap_call_all_fns rebuilt on vmap_new_states (fixes a BatchTracer leak).
  • Map.update no longer forwards spmd_axis_name to pmap2.
  • Sequential empty-slice returns an empty Sequential.
  • in_size/out_size setters accept numpy scalars and 0-d arrays uniformly.

Delay / dynamics / event_fixedprob

  • Delay.max_time grows monotonically across registrations.
  • take_aware_unit retrieval no longer crashes / double-applies the unit.
  • update_every made functional via a monotonic per-call write pointer, with correct slot-spacing on time-based retrieval.
  • FixedNumConn afferent-ratio mask respects seed; broken efferent_target='pre' path guarded with a clear NotImplementedError.

Testing

  • TDD throughout (red → green); regression test added per finding.
  • python -m pytest brainstate/nn/ -q → 1856 passed, 0 failed, 5 skipped.

🤖 Generated with Claude Code

Systematic audit of brainstate.nn. Each fix is covered by a regression test
(several previously-skipped "known bug" tests are now un-skipped and pass).

Dropout/elementwise/activations:
- AlphaDropout/FeatureAlphaDropout: correct self-normalizing affine constants.
- Dropout2d/3d unbatched minimal-dim detection (per-element mask independence).
- Softmin/Softmax/LogSoftmax default dim -> last axis; rrelu unit/int handling;
  soft_shrink zero-branch unit; channel-last docstrings.

Linear/init/utils:
- ScaledWSLinear mask/weight/bias shapes; AllToAll out>in padding.
- TruncatedNormal default bounds; clip_grad_norm unitless-gradient note.

Metrics/exp_euler:
- Precision/Recall 'weighted' average + average validation; Welford int counter.
- exp_euler diagonal-Jacobian docstring clarification.

Transforms/param/hidata:
- Softplus/NegSoftplus/Negative/Ordered: stable saturation-free forward and
  unit-safe stable inverse; Sigmoid/Affine log_abs_det_jacobian unit handling
  and per-batch shape; Affine zero-scale check on mantissa; Exp/Log/Positive
  saturation docs; HiData.clone/add/pop/replace preserve name.

Module/common/collective_ops:
- assign_state_values: pytree/Quantity values via tree.map; accept dotted-string
  and tuple keys. _filter_states dict branch iterates items(). vmap_call_all_fns
  rebuilt on vmap_new_states (fixes BatchTracer leak). Map.update no longer
  forwards spmd_axis_name to pmap2. Sequential empty-slice returns empty
  Sequential. in/out_size setters accept numpy scalars and 0-d arrays uniformly.

Delay/dynamics/event_fixedprob:
- Delay.max_time grows monotonically across registrations; take_aware_unit
  retrieval no longer crashes / double-applies unit; update_every made functional
  via a monotonic per-call write pointer and correct slot-spacing on time
  retrieval. FixedNumConn afferent_ratio mask respects seed; broken
  efferent_target='pre' path guarded with a clear NotImplementedError.
@chaoming0625 chaoming0625 merged commit 95b4e24 into main Jun 11, 2026
5 checks passed
@chaoming0625 chaoming0625 deleted the worktree-nn-module-audit branch June 11, 2026 15:21
@codecov

codecov Bot commented Jun 11, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 95.76720% with 8 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
brainstate/nn/_transform.py 85.00% 2 Missing and 1 partial ⚠️
brainstate/nn/_conv.py 93.10% 1 Missing and 1 partial ⚠️
brainstate/nn/_hidata.py 75.00% 1 Missing ⚠️
brainstate/nn/_metrics.py 93.33% 1 Missing ⚠️
brainstate/nn/_poolings.py 97.05% 0 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @chaoming0625, you have reached your weekly rate limit of 500000 diff characters.

Please try again later or upgrade to continue using Sourcery

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant