
Refactor NPE-C loss logic to Strategy Pattern#1755

Open
Sumit6307 wants to merge 4 commits into sbi-dev:main from Sumit6307:refactor/npe-c-strategy

Conversation

@Sumit6307

Summary

This PR refactors the NPE_C trainer to use the Strategy Pattern for its loss calculation logic. It isolates the "atomic" (sample-based) and "non-atomic" (analytical Gaussian) loss implementations into separate strategy classes, significantly improving the readability and maintainability of the core trainer class.

Motivation

The NPE_C class previously contained complex logic for switching between two very different loss calculation methods, which tightly coupled the mathematical details of the loss functions with the training loop. Following the Single Responsibility Principle, this refactor:

  • Decouples loss implementation from the training orchestration.
  • Makes it easier to extend NPE_C with new loss types in the future (e.g., different divergences) without modifying the main loop.
  • Improves code clarity by removing over 100 lines of mixed-concern code from npe_c.py.

Changes

  • New File: sbi/inference/trainers/npe/npe_c_loss.py
    • Implemented AtomicLoss: Encapsulates the sampling-based loss logic.
    • Implemented NonAtomicGaussianLoss: Encapsulates the analytical MoG loss logic including automatic posterior transformation.
  • Modified: sbi/inference/trainers/npe/npe_c.py
    • Removed private methods _log_prob_proposal_posterior_atomic, _log_prob_proposal_posterior_mog, and _automatic_posterior_transformation.
    • Instantiates the appropriate strategy in train() based on the proposal and prior type.
    • Delegates execution to the strategy object.
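The delegation described above can be sketched as follows. This is an illustrative stand-in, not the actual sbi implementation: the class names follow the PR, but the loss bodies and the selection condition are simplified placeholders (the real losses operate on torch.Tensor batches and inspect the proposal/prior types).

```python
# Hypothetical sketch of the Strategy Pattern described in this PR.
# AtomicLoss / NonAtomicGaussianLoss exist in the PR, but their bodies
# and the trainer's selection logic here are simplified stand-ins.

class AtomicLoss:
    """Sample-based ("atomic") loss: contrasts each theta against atoms."""

    def __init__(self, num_atoms: int = 10):
        self.num_atoms = num_atoms

    def __call__(self, theta, x):
        # Stand-in computation; the real loss builds atomic contrast sets.
        return sum(t * v for t, v in zip(theta, x)) / self.num_atoms


class NonAtomicGaussianLoss:
    """Analytical loss for Gaussian/MoG proposals and estimators."""

    def __call__(self, theta, x):
        # Stand-in computation; the real loss is evaluated in closed form.
        return sum((t - v) ** 2 for t, v in zip(theta, x))


class NPECTrainer:
    """The trainer only orchestrates; the strategy object owns the math."""

    def __init__(self, proposal_is_gaussian: bool):
        # Mirrors the PR: pick the strategy from the proposal/prior type.
        if proposal_is_gaussian:
            self._loss_strategy = NonAtomicGaussianLoss()
        else:
            self._loss_strategy = AtomicLoss()

    def loss(self, theta, x):
        # Delegation: the training loop never touches loss internals.
        return self._loss_strategy(theta, x)
```

The training loop stays identical regardless of which loss is active, which is what makes the extraction of the two implementations safe.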

Verification

  • Logic Preservation: Manually verified that the mathematical operations and logic flow in the new strategy classes match the original implementation exactly line-by-line.
  • Static Analysis: Verified imports and syntax correctness.
  • Behavior: The refactor is purely structural; no changes were made to the underlying mathematical definitions of the NPE-C loss.

@Sumit6307
Author

Hi @janfb, please have a look at this PR. Thanks!

@janfb
Contributor

janfb commented Feb 6, 2026

Hi @Sumit6307 , thank you for this PR as well. I assume you have read my comment under the PR (#1756 (comment)). That said, this contribution looks very good on a high level. I suggest the following:

@Sumit6307
Author

> Hi @Sumit6307 , thank you for this PR as well. I assume you have read my comment under the PR (#1756 (comment)). That said, this contribution looks very good on a high level. I suggest the following:

@janfb
Hi, thank you for the feedback!

Yes, I have read your comment under #1756. I’m glad to hear that the contribution looks good at a high level.

I’ll go through #1241 and @michaeldeistler’s proposal in detail and review how my PR aligns with the planned new organization of SNPE methods. Based on that, I’ll update the implementation and clarify the alignment where needed.

I’ll also run ruff and pyright locally and fix the remaining linting and type-checking issues so that CI passes.

Please let me know if you’d prefer me to continue iterating on this within the current PR, or if it would be better to first open or move the discussion to a dedicated issue.

Thanks for the guidance.

@janfb
Contributor

janfb commented Feb 6, 2026

Sounds good @Sumit6307 , let's discuss the implementation here in the PR.

@janfb
Contributor

janfb commented Mar 16, 2026

Hi @Sumit6307 are you planning to continue working on this? If not, no problem, please let us know 🙏

@Sumit6307
Author

> Hi @Sumit6307 are you planning to continue working on this? If not, no problem, please let us know 🙏

@janfb
Hi! Yes, I am definitely continuing to work on this.

I have just pushed a major update to the branch that:

1. Resolves all merge conflicts with the latest main branch.
2. Completes the Strategy Pattern refactor: the loss calculation logic is now isolated into AtomicLoss and NonAtomicGaussianLoss classes in a new npe_c_loss.py module.
3. Aligns with the new architecture: I've updated the implementation to use the new MixtureDensityEstimator, MoG, and Tracker interfaces.
4. Fixes CI/linting: all previous ruff linting and formatting issues have been resolved.

The PR is now ready for a fresh review. Thank you for your patience! 🙏

@janfb
Contributor

janfb commented Mar 23, 2026

Hi @Sumit6307, thank you for the update and for your continued work on this!

I took this review as an opportunity to think more broadly about how we handle NPE
losses and what we were thinking of back when creating this issue.

Overview

Your extraction of the atomic and non-atomic loss into separate classes is the right
direction, we've been wanting to do this (see #1241). After reviewing your suggestions,
I made a more comprehensive refactoring plan that covers not just NPE-C but also NPE-B
(importance-weighted loss) and the base class dispatch. The core idea is:

  1. A NPELossStrategy protocol: lightweight strategy objects that all share __call__(theta, x, masks, proposal) -> Tensor and a uses_only_latest_round property (replacing the use_non_atomic_loss flag and the hasattr(self, "_ran_final_round") check in _get_start_index)
  2. Three concrete strategies in a single npe_loss.py module: AtomicLoss,
    NonAtomicGaussianLoss, and ImportanceWeightedLoss. The round-0 MLE loss has no
    strategy; it is the fallback (just a three-liner).
  3. User-composable losses: train() accepts an optional loss_strategy parameter, enabling the API from #1241 (New organization of SNPE methods): trainer.train(loss_strategy=AtomicLoss(...))
  4. Eliminating the "sneaky trick": the scattered instance attributes (_num_atoms, _use_combined_loss, use_non_atomic_loss, etc.) get replaced by a single typed strategy object on self._loss_strategy
  5. Removing the abstract _log_prob_proposal_posterior from the base class entirely — the strategy object replaces this dispatch mechanism
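Roughly, the protocol could look like this. A sketch only: the real sbi types (Tensor, the proposal class) are replaced with placeholders, and the AtomicLoss body is a stand-in.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class NPELossStrategy(Protocol):
    """Sketch of the proposed protocol; signatures are illustrative."""

    # Replaces the use_non_atomic_loss flag and the _ran_final_round check.
    uses_only_latest_round: bool

    def __call__(self, theta: Any, x: Any, masks: Any, proposal: Any) -> Any:
        ...


class AtomicLoss:
    """Stand-in concrete strategy; satisfies the protocol structurally
    without inheriting from it (users can bring their own callables)."""

    uses_only_latest_round = False  # atomic loss trains on all rounds

    def __init__(self, num_atoms: int = 10):
        self.num_atoms = num_atoms

    def __call__(self, theta, x, masks, proposal):
        return 0.0  # placeholder for the sample-based loss
```

Because it is a Protocol rather than an ABC, any object with the right attributes type-checks as a strategy, so `trainer.train(loss_strategy=my_callable)` works without subclassing.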

What this means for your PR

Your work identified the right abstraction boundary and showed the extraction is feasible. However, the scope of what we want to do is larger than what this PR covers, and merging it as-is would mean immediately refactoring it again. So I'd like to offer you two options:

(a) You take ownership of Phase 1 of the plan. This would mean updating your PR to:

  • Use a Protocol instead of an ABC (so users can bring their own loss callables)
  • Add uses_only_latest_round: bool to each strategy class (replaces use_non_atomic_loss flag)
  • Fix the type constraint on AtomicLoss: it should accept ConditionalDensityEstimator, not MixtureDensityEstimator (atomic loss works with flows, not just MDNs)
  • Preserve all mathematical docstrings when moving code (the Greenberg et al. 2019
    Appendix A1 derivations, the z-scoring explanation, the formula comments; it seems you
    removed those on purpose in your current diff)
  • Add ImportanceWeightedLoss (extracted from NPE_B)
  • Add unit tests for the strategy classes
  • Rename the file to npe_loss.py (it's no longer NPE-C-specific)
  • Don't instantiate a strategy for round 0 (the base class _loss() uses MLE directly)
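For the unit-test item, a minimal sketch might look like the following. The AtomicLoss here is a hypothetical stand-in with the interface discussed in this thread; test names and the `num_atoms >= 2` constraint are illustrative.

```python
# Hypothetical unit-test sketch for the strategy classes.

class AtomicLoss:
    uses_only_latest_round = False

    def __init__(self, num_atoms: int = 10):
        # Illustrative constraint: contrasting needs at least two atoms.
        if num_atoms < 2:
            raise ValueError("atomic loss needs at least 2 atoms")
        self.num_atoms = num_atoms


def test_atomic_loss_stores_num_atoms():
    loss = AtomicLoss(num_atoms=5)
    assert loss.num_atoms == 5


def test_atomic_loss_trains_on_all_rounds():
    assert AtomicLoss().uses_only_latest_round is False


def test_atomic_loss_rejects_too_few_atoms():
    try:
        AtomicLoss(num_atoms=1)
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError")
```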

I'm happy to provide more detailed guidance on any of these points.

(b) We take it from here. This would be totally fine as well. You've done valuable exploratory work. We'd credit you as co-author in the commits.

Technical issues with the current PR (for reference)

In case you go with option (a), here are the specific issues to address:

  • NPELoss.__init__ types neural_net as MixtureDensityEstimator, but AtomicLoss works with any ConditionalDensityEstimator
  • The hasattr(self, "_loss_strategy") check is fragile — initialize as None in __init__ and check explicitly
  • The lazy import inside train() is unnecessary — no circular import risk
  • The assert_all_finite error message for the MoG path was truncated — the original "This is likely due to a numerical instability... Please create an issue on Github" helps users diagnose problems
  • CI hasn't run on this PR — please verify with pytest -m "not slow and not gpu", ruff check sbi/, and pyright sbi/
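The hasattr point can be illustrated with a small before/after sketch (the attribute name follows the thread; the classes are otherwise hypothetical):

```python
from typing import Optional


class TrainerFragile:
    """Before: the attribute may or may not exist, so callers must probe."""

    def has_strategy(self) -> bool:
        # Fragile: a typo in the attribute name silently returns False.
        return hasattr(self, "_loss_strategy")


class TrainerExplicit:
    """After: the attribute always exists; absence is an explicit None."""

    def __init__(self) -> None:
        self._loss_strategy: Optional[object] = None

    def has_strategy(self) -> bool:
        return self._loss_strategy is not None
```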

Let me know which option you'd prefer, and thanks again for pushing this forward!

P.S.: In case you were planning to apply for GSoC with us, just to manage expectations,
the deadline is approaching soon and we have quite a number of applications and
proposals to review already, so it's unlikely we will find time to review more this
week.

@Sumit6307
Author

> Hi @Sumit6307, thank you for the update and for your continued work on this! […]

Hi @janfb,

Thank you for the detailed feedback and the broader context regarding #1241!

I agree that a clean NPELossStrategy protocol and a unified npe_loss.py module will greatly simplify the architecture across NPE-A, NPE-B, and NPE-C. I would be happy to take ownership of Option (a) and implement Phase 1.

I have just pushed the commits for Phase 1! The changes include:

  1. Created the NPELossStrategy Protocol (replacing ABC).
  2. Extracted ImportanceWeightedLoss, AtomicLoss, and NonAtomicGaussianLoss into the shared npe_loss.py.
  3. Updated the trainers (NPE_C, NPE_B, NPE_A) to use the new strategies.
  4. Removed the abstract method _log_prob_proposal_posterior from the base class.
  5. Restored the mathematical derivations in the docstrings for NonAtomicGaussianLoss.
  6. Included basic unit tests for strategy initializations in tests/inference/npe_loss_test.py.

Regarding GSoC: I completely understand you are all very busy with the upcoming deadline and the high volume of proposals, so please do not feel rushed to review this PR right away! I am officially submitting my GSoC proposal to NumFOCUS for sbi tomorrow. I have really enjoyed continuously contributing to the sbi-dev/sbi repository and open-source over these past weeks. Through this PR and my previous work, I feel I have gained a deep, hands-on understanding of the sbi codebase and its architecture. I would be absolutely thrilled if my proposal is accepted, as I am fully prepared and highly motivated to dedicate my summer to continuing this work!

Thank you again for all the valuable guidance you've provided so far.

@janfb
Contributor

janfb commented Apr 16, 2026

Hi @Sumit6307

Thanks for the update, and sorry for the delayed response; we had a high load of PRs and other maintenance work recently due to GSoC.

I will review this soon and I have also seen your PR #1826 on the analog refactor for NRE, that's great!

Thanks for your patience 🙏
Jan
