Skip to content

fix(trainer,visualize): correct six latent API bugs (#92-#97)#98

Merged
chaoming0625 merged 1 commit into
mainfrom
fix-latent-bugs
Jun 9, 2026
Merged

fix(trainer,visualize): correct six latent API bugs (#92-#97)#98
chaoming0625 merged 1 commit into
mainfrom
fix-latent-bugs

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

Summary

Fixes six latent API bugs surfaced while raising braintools test coverage above 90% (#90). Each bug is a hard crash (TypeError/AttributeError/wrong-API) on a documented public call path; all are covered by the regression tests added here.

Issue Symptom Fix
#92 LightningModule.deviceTypeError: 'set' object is not subscriptable jax.Array.devices() returns a set; take next(iter(...))
#93 visualize.animate_2D → pcolor called on a 1-D array reshape values to the (height, width) grid before FuncAnimation
#94 visualize.correlation_matrix(method='kendall')kendalltau rejects a 2-D matrix build the correlation matrix pairwise over feature columns
#95 trainer.ModelCheckpoint cannot save — uses msgpack_from_state_dict (a restorer) as a writer write with file.msgpack_save
#96 visualize.remove_axisAttributeError: 'Axes' object has no attribute 'spine' ax.spineax.spines
#97 create_neural_colormap/brain_colormaps raise on re-register register with force=True (idempotent)

CI

Adds tqdm and rich to the testing extra so the progress-bar callback tests (TestTQDMProgressBar) actually run in CI instead of being silently skipped — this is what broke main after #90 merged.

Verification

Full suite green locally (pytest braintools --cov), total coverage 91.77%. Tests previously asserting the buggy behaviour were updated to assert the fixed behaviour.

Notes

Changelog updated under v0.1.10.

Closes #92, #93, #94, #95, #96, #97

Summary by Sourcery

Fix trainer checkpointing and device detection and resolve multiple visualize API crashes, adding regression tests and restoring CI progress-bar coverage.

Bug Fixes:

  • Ensure LightningModule.device works with array-backed parameters by handling jax.Array.devices() correctly.
  • Fix ModelCheckpoint to write checkpoints via the msgpack_save API so saves succeed and can be restored.
  • Update animate_2D to reshape inputs to the specified grid before drawing, preventing pcolor from crashing.
  • Make correlation_matrix(method='kendall') compute pairwise feature correlations instead of passing a 2-D matrix to kendalltau.
  • Correct remove_axis to operate on ax.spines so valid positions hide the intended axes without raising.
  • Make create_neural_colormap and brain_colormaps idempotent by forcing re-registration instead of raising on reuse.

CI:

  • Add tqdm and rich to the testing extra so progress-bar callback tests run in CI instead of being skipped.

Documentation:

  • Update changelog entries to document trainer and visualize bug fixes and CI testing changes.

Tests:

  • Add regression tests for LightningModule.device with array parameters and for real ModelCheckpoint save/load roundtrips.
  • Extend visualize tests to cover animate_2D execution, remove_axis spine hiding, Kendall correlation matrices, and idempotent colormap registration.

Fixes bugs surfaced while raising test coverage:

- LightningModule.device: jax.Array.devices() returns a set, which is
  not subscriptable; take an arbitrary element via next(iter(...)). (#92)
- animate_2D: reshape the (num_step, num_neuron) values to the
  (height, width) grid before FuncAnimation, instead of relying on a
  later reshape that never ran. (#93)
- correlation_matrix(method='kendall'): kendalltau only compares two
  1-D samples; build the feature-by-feature matrix pairwise. (#94)
- ModelCheckpoint._save_checkpoint: use file.msgpack_save to write the
  checkpoint instead of msgpack_from_state_dict (a restorer). (#95)
- remove_axis: ax.spine -> ax.spines (AttributeError otherwise). (#96)
- create_neural_colormap/brain_colormaps: register with force=True so
  re-registering an existing name is idempotent instead of raising. (#97)

Also add tqdm and rich to the testing extra so the progress-bar
callback tests run in CI, and update the affected tests to assert the
fixed behaviour. Changelog updated.

Closes #92, #93, #94, #95, #96, #97
@sourcery-ai

sourcery-ai Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Reviewer's Guide

This PR fixes six latent public API crashes in the trainer and visualize modules, corrects the ModelCheckpoint serialization path, and makes a few visualization utilities more robust and idempotent, while adding targeted regression tests and ensuring progress-bar tests run in CI by updating test dependencies.

Sequence diagram for ModelCheckpoint msgpack_save path

sequenceDiagram
    participant Trainer
    participant ModelCheckpoint
    participant Module
    participant braintools_file

    Trainer->>ModelCheckpoint: _save_checkpoint(trainer, module, filepath)
    ModelCheckpoint->>Module: state_dict()
    ModelCheckpoint->>Module: current_epoch
    ModelCheckpoint->>Module: global_step
    ModelCheckpoint->>braintools_file: msgpack_save(filepath, checkpoint, verbose=False)
Loading

File-Level Changes

Change Details Files
Fix LightningModule.device for array-backed parameters so it no longer crashes on jax.Array.devices() returning a set.
  • Change device property to iterate over the devices set instead of indexing it, returning an arbitrary device if available.
  • Add a regression test that constructs a LightningModule with a brainstate.ParamState wrapping a JAX array and asserts device is non-None.
braintools/trainer/_module.py
braintools/trainer/_module_test.py
Correct ModelCheckpoint to use the proper msgpack writer API and adjust tests to match the new signature.
  • Replace use of msgpack_from_state_dict in _save_checkpoint with msgpack_save(filepath, checkpoint, verbose=False).
  • Update callback tests to monkeypatch bf.msgpack_save with a fake(path, target, *args, **kwargs) stub and validate saved checkpoint contents, including a new real roundtrip test using msgpack_load.
  • Update integration test to patch bf.msgpack_save instead of msgpack_from_state_dict for faster runs.
braintools/trainer/_callbacks.py
braintools/trainer/_callbacks_test.py
Make animate_2D work with 2D value arrays by reshaping once up front and add a regression test.
  • Convert values to a NumPy array and reshape to (num_step, height, width) before computing min/max and before the frame function uses it.
  • Remove the later reshape call inside the function to avoid double reshaping.
  • Add a TestAnimate2D test case that generates random data, calls animate_2D with show=False, and asserts it returns a non-None animation object.
  • Ensure animate_2D is imported into the plots extra tests module.
braintools/visualize/_plots.py
braintools/visualize/_plots_extra_test.py
Fix correlation_matrix(method='kendall') to build a proper feature-by-feature matrix rather than calling kendalltau on a 2D array.
  • Implement pairwise Kendall’s tau computation over feature columns, initializing an identity matrix and filling off-diagonal entries with tau values.
  • Update the kendall branch test to expect a valid Axes return instead of an exception, confirming the new behavior.
braintools/visualize/_statistical.py
braintools/visualize/_statistical_extra_test.py
Correct remove_axis to operate on Matplotlib’s spines collection and test that spines are actually hidden.
  • Change ax.spine[p].set_visible(False) to ax.spines[p].set_visible(False) inside remove_axis.
  • Update the remove_axis test to call remove_axis with valid side names and assert that the corresponding spines are no longer visible.
braintools/visualize/_plots.py
braintools/visualize/_plots_extra_test.py
Make neural colormap registration idempotent and cover repeat registration in tests.
  • Register colormaps with plt.colormaps.register(cmap, name=name, force=True) so reusing a name overrides without raising.
  • Add a regression test that calls create_neural_colormap twice with the same name and asserts the colormap is present.
  • Changelog entry describing the idempotent behavior for create_neural_colormap/brain_colormaps.
braintools/visualize/_colormaps.py
braintools/visualize/_colormaps_extra_test.py
changelog.md
Ensure progress-bar callback tests run in CI and document the dependency change.
  • Add tqdm and rich to the testing extra in pyproject.toml so TestTQDMProgressBar and related tests have their runtime dependencies.
  • Update changelog infrastructure section to mention the new testing dependencies and their role in coverage/CI.
pyproject.toml
changelog.md

Assessment against linked issues

Issue Objective Addressed Explanation
#92 Update LightningModule.device in braintools/trainer/_module.py so that it correctly handles JAX Array.devices() returning a set (and no longer raises TypeError when accessing .device).

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625 chaoming0625 merged commit b8ae6bd into main Jun 9, 2026
2 of 5 checks passed
@chaoming0625 chaoming0625 deleted the fix-latent-bugs branch June 9, 2026 16:29

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've left some high level feedback:

  • In animate_2D, consider adding an explicit check that num_neuron == height * width before reshaping so shape mismatches fail fast with a clear error instead of silently producing incorrect grids.
  • The new Kendall correlation path builds the matrix in nested Python loops; if this will be used with many features, consider a more efficient approach (e.g., precomputing columns and/or vectorizing the pairwise calls) to avoid quadratic Python overhead.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `animate_2D`, consider adding an explicit check that `num_neuron == height * width` before reshaping so shape mismatches fail fast with a clear error instead of silently producing incorrect grids.
- The new Kendall correlation path builds the matrix in nested Python loops; if this will be used with many features, consider a more efficient approach (e.g., precomputing columns and/or vectorizing the pairwise calls) to avoid quadratic Python overhead.

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

chaoming0625 added a commit that referenced this pull request Jun 9, 2026
…) (#100)

np.trapz was renamed to np.trapezoid in NumPy 2.0 and removed in NumPy
2.4, so roc_curve and precision_recall_curve raised AttributeError on
NumPy 2.4+ (which CI installs). Resolve np.trapezoid when available and
fall back to np.trapz for the declared numpy>=1.15 floor.

This broke main CI after #98 merged; it was invisible locally because
local NumPy was < 2.4.

Closes #99
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

LightningModule.device raises TypeError ('set' object is not subscriptable)

1 participant