Move corrector_disabled_epochs scheduling into corrector#1266
Conversation
| """Called by the stepper at the start of each training epoch.""" | ||
| pass | ||
|
|
||
| def get_state(self) -> dict[str, Any]: |
There was a problem hiding this comment.
Suggestion: Make all methods on this ABC @abstract (only implemented on the children) or @final (only implementable on the base). Likely they should all be abstract. It will help to not have to scan back and forth between this ABC and the subclasses to understand which methods are implemented where.
| self.module.load_state(state["module"]) | ||
| self.radiation_module.load_state(state["radiation_module"]) | ||
| if "corrector" in state: | ||
| self._load_corrector_state(state["corrector"]) |
There was a problem hiding this comment.
Question: Would you want to call _load_corrector_state({}) if the saved corrector had no state? For example to raise if somehow we had a has-state corrector and were trying to load a no-corrector-state checkpoint?
| self.call_count = 0 | ||
| self.seen_states: list[CorrectorState | None] = [] | ||
|
|
||
| def train(self, mode: bool = True) -> "_RecordingCorrector": |
There was a problem hiding this comment.
Question: Should there be tests exercising these new methods? I didn't notice earlier but I think an AI reviewer would call it out.
|
|
||
| @dataclasses.dataclass | ||
| class CorrectorConfigABC(abc.ABC): | ||
| corrector_disabled_epochs: int = dataclasses.field(default=0, kw_only=True) |
There was a problem hiding this comment.
Chore: Remove this property from the ABC, right? Is it required for all CorrectorConfig?
mcgibbon
left a comment
There was a problem hiding this comment.
LGTM, one possible nit/chore to look at.
| """Set the step (and all submodules) to evaluation mode.""" | ||
| return self.train(False) | ||
|
|
||
| def _set_corrector_train_mode(self, mode: bool) -> None: |
There was a problem hiding this comment.
getattr is a red flag, and I'm realizing from it that I didn't notice this is on the ABC. The corrector is specific to the subclass Step, shouldn't this code go there? That means train should be abstract instead of final - you could make a case for that one making an exception and allowing a default, but if you do that, at least use the super-call version to do the base part instead of copy-pasting.
6046ad5
into
feature/disable-corrector-first-epochs
These changes move
corrector_disabled_epochsout ofSingleModuleStepand into the shared corrector configuration layer. Correctors are now wrapped byEpochScheduledCorrector, which can disable any corrector during the first N training epochs while still applying it during eval/validation/inference.This makes the behavior available to atmosphere, ocean, ice, and future correctors without copying logic into each Step. Step train/eval, epoch, and checkpoint state now forward through the corrector lifecycle so mid-epoch resume preserves the scheduler state.
Resolves #1261