fix(trainer,visualize): correct six latent API bugs (#92-#97)#98
Merged
Conversation
Fixes bugs surfaced while raising test coverage: - LightningModule.device: jax.Array.devices() returns a set, which is not subscriptable; take an arbitrary element via next(iter(...)). (#92) - animate_2D: reshape the (num_step, num_neuron) values to the (height, width) grid before FuncAnimation, instead of relying on a later reshape that never ran. (#93) - correlation_matrix(method='kendall'): kendalltau only compares two 1-D samples; build the feature-by-feature matrix pairwise. (#94) - ModelCheckpoint._save_checkpoint: use file.msgpack_save to write the checkpoint instead of msgpack_from_state_dict (a restorer). (#95) - remove_axis: ax.spine -> ax.spines (AttributeError otherwise). (#96) - create_neural_colormap/brain_colormaps: register with force=True so re-registering an existing name is idempotent instead of raising. (#97) Also add tqdm and rich to the testing extra so the progress-bar callback tests run in CI, and update the affected tests to assert the fixed behaviour. Changelog updated. Closes #92, #93, #94, #95, #96, #97
Contributor
Reviewer's GuideThis PR fixes six latent public API crashes in the trainer and visualize modules, corrects the ModelCheckpoint serialization path, and makes a few visualization utilities more robust and idempotent, while adding targeted regression tests and ensuring progress-bar tests run in CI by updating test dependencies. Sequence diagram for ModelCheckpoint msgpack_save pathsequenceDiagram
participant Trainer
participant ModelCheckpoint
participant Module
participant braintools_file
Trainer->>ModelCheckpoint: _save_checkpoint(trainer, module, filepath)
ModelCheckpoint->>Module: state_dict()
ModelCheckpoint->>Module: current_epoch
ModelCheckpoint->>Module: global_step
ModelCheckpoint->>braintools_file: msgpack_save(filepath, checkpoint, verbose=False)
File-Level Changes
Assessment against linked issues
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
Contributor
There was a problem hiding this comment.
Hey - I've left some high level feedback:
- In
animate_2D, consider adding an explicit check thatnum_neuron == height * widthbefore reshaping so shape mismatches fail fast with a clear error instead of silently producing incorrect grids. - The new Kendall correlation path builds the matrix in nested Python loops; if this will be used with many features, consider a more efficient approach (e.g., precomputing columns and/or vectorizing the pairwise calls) to avoid quadratic Python overhead.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `animate_2D`, consider adding an explicit check that `num_neuron == height * width` before reshaping so shape mismatches fail fast with a clear error instead of silently producing incorrect grids.
- The new Kendall correlation path builds the matrix in nested Python loops; if this will be used with many features, consider a more efficient approach (e.g., precomputing columns and/or vectorizing the pairwise calls) to avoid quadratic Python overhead.Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
This was referenced Jun 9, 2026
chaoming0625
added a commit
that referenced
this pull request
Jun 9, 2026
…) (#100) np.trapz was renamed to np.trapezoid in NumPy 2.0 and removed in NumPy 2.4, so roc_curve and precision_recall_curve raised AttributeError on NumPy 2.4+ (which CI installs). Resolve np.trapezoid when available and fall back to np.trapz for the declared numpy>=1.15 floor. This broke main CI after #98 merged; it was invisible locally because local NumPy was < 2.4. Closes #99
This was referenced Jun 10, 2026
Closed
Closed
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes six latent API bugs surfaced while raising braintools test coverage above 90% (#90). Each bug is a hard crash (
TypeError/AttributeError/wrong-API) on a documented public call path; all are covered by the regression tests added here.LightningModule.device→TypeError: 'set' object is not subscriptablejax.Array.devices()returns a set; takenext(iter(...))visualize.animate_2D→ pcolor called on a 1-D array(height, width)grid beforeFuncAnimationvisualize.correlation_matrix(method='kendall')→kendalltaurejects a 2-D matrixtrainer.ModelCheckpointcannot save — usesmsgpack_from_state_dict(a restorer) as a writerfile.msgpack_savevisualize.remove_axis→AttributeError: 'Axes' object has no attribute 'spine'ax.spine→ax.spinescreate_neural_colormap/brain_colormapsraise on re-registerforce=True(idempotent)CI
Adds
tqdmandrichto thetestingextra so the progress-bar callback tests (TestTQDMProgressBar) actually run in CI instead of being silently skipped — this is what brokemainafter #90 merged.Verification
Full suite green locally (
pytest braintools --cov), total coverage 91.77%. Tests previously asserting the buggy behaviour were updated to assert the fixed behaviour.Notes
Changelog updated under v0.1.10.
Closes #92, #93, #94, #95, #96, #97
Summary by Sourcery
Fix trainer checkpointing and device detection and resolve multiple visualize API crashes, adding regression tests and restoring CI progress-bar coverage.
Bug Fixes:
CI:
Documentation:
Tests: