
LC2ST Module Refactoring#1727

Open
janfb wants to merge 13 commits into main from refactor-lc2st

@janfb
Contributor

@janfb janfb commented Jan 14, 2026

LC2ST Module Refactoring

Summary

This PR refactors the LC2ST diagnostics module to improve API clarity, error handling, and maintainability while preserving full backward compatibility.

Note 1: I used Claude Code for this refactoring. I reviewed every commit myself, as did a very critical Gemini Pro reviewing agent. I believe the changes are useful and correct, but reviewers should still take extra care and be extra skeptical.

Note 2: I started this in parallel with, or rather motivated by, the flaky test spotted in #1715, so we need to resolve potential conflicts between these PRs down the line.

Motivation

The LC2ST implementation had accumulated several issues:

  1. Confusing API workflow - Users could call methods in arbitrary order with silent failures or cryptic errors
  2. Assertions for validation - Input validation used assert statements, which are stripped in optimized builds
  3. Code duplication - Z-score normalization logic was repeated in 4+ locations
  4. Unstructured returns - get_scores() returned different types based on a boolean flag
  5. Double-normalization bug - Data was normalized twice in null hypothesis training paths

Design Choices

State Machine (LC2STState enum): Rather than documenting method call order, the refactoring enforces it through explicit state transitions. The LC2ST object progresses through INITIALIZED → OBSERVED_TRAINED/NULL_TRAINED → READY, with methods checking state before execution. Both training orders are valid and reach READY.
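
A minimal sketch of how such state checks could look. The state names follow the description above; the class `TinyLC2ST`, the choice of `RuntimeError`, and the method bodies are hypothetical stand-ins for illustration, not the code in sbi/diagnostics/lc2st.py.

```python
from enum import Enum, auto


class LC2STState(Enum):
    INITIALIZED = auto()
    OBSERVED_TRAINED = auto()
    NULL_TRAINED = auto()
    READY = auto()


class TinyLC2ST:
    """Hypothetical stand-in that only tracks state; nothing is trained."""

    def __init__(self):
        self._state = LC2STState.INITIALIZED

    def train_on_observed_data(self):
        # Observed training completes the setup once the null side is done.
        if self._state in (LC2STState.NULL_TRAINED, LC2STState.READY):
            self._state = LC2STState.READY
        else:
            self._state = LC2STState.OBSERVED_TRAINED

    def train_under_null_hypothesis(self):
        # Assumption: re-entering null training is rejected with an error.
        if self._state in (LC2STState.NULL_TRAINED, LC2STState.READY):
            raise RuntimeError("Null classifiers are already trained.")
        if self._state is LC2STState.OBSERVED_TRAINED:
            self._state = LC2STState.READY
        else:
            self._state = LC2STState.NULL_TRAINED

    def p_value(self):
        # Methods check the state before doing any work.
        if self._state is not LC2STState.READY:
            raise RuntimeError(
                "p_value() requires state READY, but the current state is "
                f"{self._state.name}. Call both training methods first."
            )
        return 0.5  # placeholder, not a real test result
```

Either training order ends in READY, and calling `p_value()` too early fails with a message that names the current state.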

Structured Returns (LC2STScores dataclass): Replaces the Union[array, Tuple[array, array]] return type with a dataclass containing scores and optional probabilities. The old return_probs=True behavior is preserved but deprecated.
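
A rough sketch of what such a dataclass might look like. Only the `scores` attribute is confirmed elsewhere in this thread; the field name `probabilities`, the function `fake_get_scores`, and the dummy statistic are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class LC2STScores:
    """Sketch of a structured return value (field names partly assumed)."""

    scores: np.ndarray
    probabilities: Optional[np.ndarray] = None


def fake_get_scores(n_trials: int = 3, n_eval: int = 5, seed: int = 0) -> LC2STScores:
    """Hypothetical stand-in for get_scores(); returns random dummy values."""
    rng = np.random.default_rng(seed)
    probs = rng.uniform(size=(n_trials, n_eval))
    # Dummy statistic: mean squared deviation from 0.5, one value per trial.
    scores = ((probs - 0.5) ** 2).mean(axis=1)
    return LC2STScores(scores=scores, probabilities=probs)


result = fake_get_scores()
print(result.scores.shape, result.probabilities.shape)
```

With a single return type, callers always access `result.scores` and never have to branch on whether they received an array or a tuple.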

Parameter Rename with Deprecation: thetas → prior_samples for clarity (these are samples from the prior, not parameters). The old parameter still works as a keyword-only argument and emits a DeprecationWarning.
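
One common way to implement such a rename, sketched under assumptions: the function `make_lc2st` and its error messages are hypothetical, and the real constructor takes more arguments than shown here.

```python
import warnings
from typing import Optional, Sequence


def make_lc2st(
    prior_samples: Optional[Sequence[float]] = None,
    *,
    thetas: Optional[Sequence[float]] = None,
) -> list:
    """Hypothetical constructor stand-in that still accepts `thetas`."""
    if thetas is not None:
        if prior_samples is not None:
            raise ValueError(
                "Pass either `prior_samples` or the deprecated `thetas`, not both."
            )
        warnings.warn(
            "`thetas` is deprecated, use `prior_samples` instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller's line
        )
        prior_samples = thetas
    if prior_samples is None:
        raise ValueError("`prior_samples` is required.")
    return list(prior_samples)
```

Because `prior_samples` keeps the first positional slot, existing positional calls are unaffected; only callers using `thetas=...` see the warning.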

DRY Normalization: Extracted _normalize_theta() and _normalize_x() helpers. Normalization now happens exactly once in _train() and get_scores(), fixing the double-normalization bug.

Validation via Exceptions: All assert statements replaced with ValueError/TypeError with actionable messages citing actual vs expected values.
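
To illustrate the pattern: `assert` statements are skipped entirely when Python runs with `-O`, whereas explicit exceptions always fire. The validator below is a hypothetical example; the name `check_num_trials` and the parameter it checks are not taken from the PR.

```python
def check_num_trials(num_trials) -> int:
    """Hypothetical validator reporting the actual vs. expected value."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(num_trials, bool) or not isinstance(num_trials, int):
        raise TypeError(
            f"num_trials must be an int, got {type(num_trials).__name__}."
        )
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}.")
    return num_trials
```

A failing call such as `check_num_trials(0)` then reports the offending value in the message instead of a bare `AssertionError`.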

Backward Compatibility

  • Positional arguments work unchanged
  • thetas=... keyword works with deprecation warning
  • get_scores(..., return_probs=True) works with deprecation warning
  • All existing tests pass without modification

Test Improvements

Reduced parametrization from 48 combinations to 28 focused tests using pytest patterns (dataclass fixtures, composable fixture hierarchy, parametrized validation tests).

@codecov

codecov bot commented Jan 14, 2026

Codecov Report

❌ Patch coverage is 84.42211% with 31 lines in your changes missing coverage. Please review.
✅ Project coverage is 87.88%. Comparing base (173cb7c) to head (32a346a).
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| sbi/diagnostics/lc2st.py | 84.34% | 31 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1727      +/-   ##
==========================================
- Coverage   87.94%   87.88%   -0.07%     
==========================================
  Files         140      140              
  Lines       12845    12962     +117     
==========================================
+ Hits        11297    11392      +95     
- Misses       1548     1570      +22     
| Flag | Coverage Δ |
|---|---|
| fast | 82.62% <84.42%> (?) |

| Files with missing lines | Coverage Δ |
|---|---|
| sbi/diagnostics/__init__.py | 100.00% <100.00%> (ø) |
| sbi/diagnostics/lc2st.py | 89.34% <84.34%> (-4.32%) ⬇️ |

@janfb janfb requested a review from JuliaLinhart January 15, 2026 10:57
…tionality

- Updated parameter names in LC2ST initialization for consistency (thetas -> prior_samples).
- Modified get_scores and get_statistics_under_null_hypothesis methods to return LC2STScores objects, encapsulating probabilities and scores.
- Adjusted usage of get_scores in the tutorial and tests to reflect the new return type.
- input validation in LC2ST to prevent indexing errors.
- Updated tests to assert the presence of scores in the returned null statistics.
@janfb janfb requested review from plcrodrigues and removed request for JuliaLinhart January 22, 2026 08:46
Resolved conflicts integrating GPU support (PR #1715) into the
refactored LC2ST code:
- Added device parameter to LC2ST.__init__ with GPU detection
- Integrated NeuralNetBinaryClassifier for GPU-accelerated training
- Added skorch ValidSplit and EarlyStopping configuration
- Fixed target dtype for skorch (float32) vs sklearn (int64)
- Updated tests to use refactored cal_data fixture

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@manuelgloeckler manuelgloeckler mentioned this pull request Mar 26, 2026
3 tasks
Contributor

@manuelgloeckler manuelgloeckler left a comment

Great, looks good overall, thanks for refactoring.

I do have some minor remarks, but otherwise looks ready to merge from my point of view.

I did run slow tests and they pass.

However, the current version of the notebook fails: in the Quantitative diagnostics chapter, it fails on

quantiles = np.quantile(T_null, [0, 1-conf_alpha])

because T_null is no longer an array but an LC2STScores object.
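
A likely fix is to pass the underlying array explicitly. The snippet below is a sketch with a stand-in dataclass and dummy data; only the `.scores` attribute and the names `T_null` and `conf_alpha` come from this thread.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class LC2STScores:
    """Minimal stand-in for the new return type; only `.scores` is used."""

    scores: np.ndarray


conf_alpha = 0.05
T_null = LC2STScores(scores=np.linspace(0.0, 1.0, 101))  # dummy null statistics

# np.quantile expects array-like input, so hand it the wrapped array.
quantiles = np.quantile(T_null.scores, [0, 1 - conf_alpha])
print(quantiles)
```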

Comment thread sbi/diagnostics/lc2st.py
# Set the parameters for the null hypothesis testing
self.null_distribution = flow_base_dist
self.permutation = False
self.trained_clfs_null = trained_clfs_null
Contributor

Mhh, not entirely sure if LC2ST_NF(trained_clfs_null=...) will work.

Mostly because __init__ is called normally, hence _state is left as INITIALIZED, no?

If a pretrained model is given, shouldn't it be READY?

Comment thread sbi/diagnostics/lc2st.py Outdated
)
self.trained_clfs = trained_clfs

# Update state
Contributor

Mhh, the unconditional else can be dangerous here, no?

For example:

lc2st.train_under_null_hypothesis()
lc2st.train_on_observed_data()
lc2st.p_value(...)  # works
lc2st.train_on_observed_data()  # retrain observed classifier for some reason (a bit artificial)
lc2st.p_value(...)  # now raises

If it's ready, shouldn't it stay ready?
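
The concern can be sketched with two pure transition functions. This is a reconstruction for illustration: the `State` enum and both function names are hypothetical, and the first function only mimics the "unconditional else" pattern described above rather than quoting the actual code.

```python
from enum import Enum, auto


class State(Enum):
    INITIALIZED = auto()
    OBSERVED_TRAINED = auto()
    NULL_TRAINED = auto()
    READY = auto()


def after_observed_training_unconditional_else(state: State) -> State:
    """Pattern in question: everything but NULL_TRAINED hits the else."""
    if state is State.NULL_TRAINED:
        return State.READY
    else:
        return State.OBSERVED_TRAINED  # READY gets downgraded on retrain


def after_observed_training_preserving(state: State) -> State:
    """One way to keep READY sticky when the observed classifier is retrained."""
    if state in (State.NULL_TRAINED, State.READY):
        return State.READY
    return State.OBSERVED_TRAINED
```

With the second variant, retraining the observed classifier from READY leaves the object in READY, so a subsequent `p_value(...)` call would still pass the state check.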

Comment thread sbi/diagnostics/lc2st.py Outdated
trained_clfs: List[BaseEstimator],
return_probs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+ ) -> Union[LC2STScores, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
Contributor

The bare np.ndarray type is never returned, no?

Also, I'm not sure it's necessary to add this backward compatibility here. Returning the structured object will already break existing code (e.g. get_scores(...).mean() will no longer work).

Contributor Author

True, this will break anyway. I changed the warnings to FutureWarning to make sure they are actually shown. We can then remove this in the release after the coming one.
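
For context: Python's default warning filters ignore DeprecationWarning unless it is triggered by code in `__main__`, while FutureWarning is shown by default. A sketch of the deprecated path follows; the function `get_scores_sketch`, the message text, and the dummy values are assumptions, and the dict merely stands in for LC2STScores.

```python
import warnings


def get_scores_sketch(return_probs: bool = False):
    """Hypothetical stand-in for get_scores(); values are dummies."""
    probs, scores = [0.4, 0.6], [0.01]
    if return_probs:
        warnings.warn(
            "return_probs=True is deprecated and will be removed in a future "
            "release; it returns a (probs, scores) tuple.",
            FutureWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return probs, scores
    # Stand-in for the structured LC2STScores return value.
    return {"scores": scores, "probabilities": probs}
```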

Comment thread sbi/diagnostics/lc2st.py Outdated
@@ -499,19 +809,28 @@
defaults to False.
verbosity: Verbosity level, defaults to 1.
Contributor

but defaults to 0

Comment thread sbi/diagnostics/lc2st.py
Normalized theta if z_score is enabled, otherwise unchanged theta.
"""
if self.z_score:
return (theta - self.theta_p_mean) / self.theta_p_std
Contributor

It might make sense to make this more robust, e.g. this could run into a division by zero for constant params/features.

Comment thread sbi/diagnostics/lc2st.py
Normalized x if z_score is enabled, otherwise unchanged x.
"""
if self.z_score:
return (x - self.x_p_mean) / self.x_p_std
Contributor

Same as above
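
One common guard is to replace zero standard deviations by 1.0 before dividing, so constant columns are only mean-centered. The sketch below uses NumPy and a hypothetical helper `safe_zscore`; the module itself may use different tensor types and names.

```python
import numpy as np


def safe_zscore(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Z-score `x`, treating zero standard deviations as 1.0."""
    safe_std = np.where(std == 0, 1.0, std)
    return (x - mean) / safe_std


x = np.array([[1.0, 5.0], [3.0, 5.0]])  # second column is constant
z = safe_zscore(x, x.mean(axis=0), x.std(axis=0))
print(z)
```

The constant column comes out as all zeros instead of NaN, while the varying column is normalized as usual.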

janfb added 4 commits April 18, 2026 20:56
Fixes three merge blockers from review:

- LC2ST_NF(trained_clfs_null=...) was deadlocked: __init__ left _state
  at INITIALIZED when pretrained null classifiers were passed, so
  train_on_observed_data advanced to OBSERVED_TRAINED (not READY) and
  p_value raised. Now advances to NULL_TRAINED when pretrained
  classifiers are supplied.

- train_on_observed_data downgraded READY -> OBSERVED_TRAINED on
  retrain, breaking the documented loop-over-seeds workflow from the
  tutorial. READY is now preserved.

- Advanced-tutorial notebook called np.quantile / axes.hist on the new
  LC2STScores return value. Both failing cells now extract .scores.
- Narrow get_scores / get_statistics_under_null_hypothesis return type
  to Union[LC2STScores, Tuple[np.ndarray, np.ndarray]]; the bare
  np.ndarray branch was never returned.
- Clarify return_probs deprecation warning to mention the (probs,
  scores) tuple order and the eventual removal.
- Guard z-score normalization against constant feature dimensions:
  std == 0 is replaced by 1.0 so constant columns become pass-through
  (mean-centered) instead of producing NaN/Inf.
- Rewrite the error message raised when re-entering
  train_under_null_hypothesis so it applies cleanly to both LC2ST
  (permutation, data-dependent) and LC2ST_NF (analytical, reusable).
- Fix docstring: verbosity in get_statistics_under_null_hypothesis
  defaults to 0, not 1.
- Add regression tests for single-normalization in null training,
  constant-dim normalization robustness, and document the
  lc2st_instance fixture scope.
@janfb
Contributor Author

janfb commented Apr 18, 2026

Good catches @manuelgloeckler , thanks!

I fixed them all and will merge.
