Refactor NRE loss logic to Strategy Pattern (Phase 2 of #1241) #1826
Sumit6307 wants to merge 6 commits into sbi-dev:main
Conversation
❌ 15 Tests Failed
JiwaniZakir left a comment
In npe_b.py, ImportanceWeightedLoss is constructed with neural_net=self._neural_net before super().train() is called (lines ~155–162). If retrain_from_scratch=True is passed, the parent's train() rebuilds self._neural_net as a new object, leaving ImportanceWeightedLoss holding a stale reference to the old network. Python raises no error here, so the bug is silent: the loss strategy would simply train against the discarded network's weights.
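A minimal, self-contained sketch of the stale-reference hazard described above. The class and method names here are simplified stand-ins, not the real sbi API; the point is only that a reference captured at construction time is never refreshed when the trainer rebuilds its network.

```python
class Net:
    """Toy stand-in for a neural network object."""
    pass


class ImportanceWeightedLossSketch:
    """Toy loss strategy that captures the network at construction time."""

    def __init__(self, neural_net):
        # Reference captured here and never refreshed.
        self.neural_net = neural_net


class TrainerSketch:
    """Toy trainer whose train() may rebuild the network."""

    def __init__(self):
        self._neural_net = Net()

    def train(self, retrain_from_scratch=False):
        if retrain_from_scratch:
            # The parent discards the old network and builds a fresh one.
            self._neural_net = Net()


trainer = TrainerSketch()
# Strategy built BEFORE train(), as in the code under review.
loss = ImportanceWeightedLossSketch(neural_net=trainer._neural_net)
trainer.train(retrain_from_scratch=True)

# The strategy now silently points at the discarded network:
print(loss.neural_net is trainer._neural_net)  # False
```

No exception is raised anywhere, which is exactly why this failure mode is easy to miss in tests that do not exercise retrain_from_scratch=True.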
Additionally, self._round = max(self._data_round_index) is assigned in npe_b.train() and presumably also in the parent's train(). This duplication is fragile: if the parent ever changes how _round is derived, the subclass override will silently diverge. It would be cleaner to initialize _loss_strategy lazily inside the parent's _loss() method or via a hook like _get_loss_strategy() that subclasses override, so the strategy is only created after the network is finalized.
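One possible shape for the lazy-initialization fix suggested above, sketched with hypothetical names (the real sbi classes and hooks differ): the parent builds the strategy only on first use inside _loss(), after train() has finalized self._neural_net, and subclasses customize it via an overridable _get_loss_strategy() hook.

```python
from typing import Optional


class Net:
    """Toy stand-in for a neural network object."""
    pass


class LossStrategy:
    """Toy loss strategy holding a reference to the network."""

    def __init__(self, neural_net):
        self.neural_net = neural_net


class Trainer:
    def __init__(self):
        self._neural_net = Net()
        self._loss_strategy: Optional[LossStrategy] = None

    def _get_loss_strategy(self) -> LossStrategy:
        # Hook that subclasses override; the base class only calls it lazily,
        # so the strategy always sees the final network object.
        return LossStrategy(neural_net=self._neural_net)

    def train(self, retrain_from_scratch=False):
        if retrain_from_scratch:
            self._neural_net = Net()  # network is rebuilt here
        self._loss_strategy = None  # invalidate any previously built strategy
        self._loss()  # training loop would call this repeatedly

    def _loss(self):
        if self._loss_strategy is None:
            self._loss_strategy = self._get_loss_strategy()
        return self._loss_strategy


trainer = Trainer()
trainer.train(retrain_from_scratch=True)

# The lazily built strategy references the *current* network:
print(trainer._loss_strategy.neural_net is trainer._neural_net)  # True
```

Because the strategy is created after any rebuild, the stale-reference problem from the previous comment cannot occur, and subclasses no longer need to duplicate round bookkeeping to construct it early.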
The guard if len(self._data_round_index) == 0 also appears to duplicate validation already present in the parent, which may lead to inconsistent error messages across NPE variants if the check is updated in one place but not the other.
Summary
Building on the maintainers' feedback and the unified vision established in Phase 1 (the NPE-C refactoring, PR #1755), this PR executes Phase 2: extracting the NRE loss calculations (NRE_A, NRE_B, NRE_C, and BNRE) into composable, isolated strategies that conform to an NRELossStrategy Protocol.
By moving the mathematically intensive routines of AALRLoss, SRELoss, CNRELoss, and BNRELoss into nre_loss.py, we eliminate _loss() and _classifier_logits() from RatioEstimatorTrainer and its subclasses entirely, yielding a fully modular ratio-estimation architecture.
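To make the strategy extraction concrete, here is a hedged sketch of what an NRELossStrategy Protocol and one conforming strategy might look like. The method name, signature, and the pure-Python toy loss are illustrative assumptions, not the exact interface merged in this PR (the real code operates on torch tensors).

```python
import math
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class NRELossStrategy(Protocol):
    """Structural interface: any callable with this signature conforms."""

    def __call__(
        self, classifier_logits: List[List[float]], num_atoms: int
    ) -> float:
        ...


class SRELossSketch:
    """Toy stand-in for an SRE-style (NRE-B) loss: cross-entropy over
    atoms, with the true (theta, x) pair placed at index 0."""

    def __call__(self, classifier_logits, num_atoms):
        losses = []
        for row in classifier_logits:
            # log-sum-exp normalizer over the atom set
            log_norm = math.log(sum(math.exp(logit) for logit in row))
            losses.append(-(row[0] - log_norm))
        return sum(losses) / len(losses)


strategy: NRELossStrategy = SRELossSketch()
logits = [[0.0] * 10 for _ in range(4)]  # batch of 4, num_atoms = 10
loss = strategy(logits, num_atoms=10)

print(isinstance(strategy, NRELossStrategy))  # True
print(round(loss, 4))  # uniform logits give loss = log(10) ≈ 2.3026
```

Because Protocol conformance is structural, the trainer can accept any of AALRLoss, SRELoss, CNRELoss, or BNRELoss without inheritance coupling; each variant is just a different object passed into the same training loop.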
Motivation
Presently, each NRE variant embeds complex classification and contrastive atom-generation logic tightly inside its own overridden _loss method. As outlined in Option (a) of #1241, trainers should only orchestrate the training loop, while composable, protocol-compliant objects handle the mathematical formulations.
Key Changes
- NRE_A._loss, NRE_B._loss, NRE_C._loss, and BNRE._loss removed.
- _loss_strategy: Optional[NRELossStrategy] added, set inside _get_losses.

Checklist
- ruff check and formatting verified.