Skip to content

Move corrector_disabled_epochs scheduling into corrector#1266

Merged
elynnwu merged 4 commits into
feature/disable-corrector-first-epochsfrom
feature/disable-corrector-top-level
Jun 15, 2026
Merged

Move corrector_disabled_epochs scheduling into corrector#1266
elynnwu merged 4 commits into
feature/disable-corrector-first-epochsfrom
feature/disable-corrector-top-level

Conversation

@elynnwu

@elynnwu elynnwu commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

These changes move corrector_disabled_epochs out of SingleModuleStep and into the shared corrector configuration layer. Correctors are now wrapped by EpochScheduledCorrector, which can disable any corrector during the first N training epochs while still applying it during eval/validation/inference.
This makes the behavior available to atmosphere, ocean, ice, and future correctors without copying logic into each Step. Step train/eval, epoch, and checkpoint state now forward through the corrector lifecycle so mid-epoch resume preserves the scheduler state.

Resolves #1261

"""Called by the stepper at the start of each training epoch."""
pass

def get_state(self) -> dict[str, Any]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Make all methods on this ABC @abstract (only implemented on the children) or @final (only implementable on the base). Likely they should all be abstract. It will help to not have to scan back and forth between this ABC and the subclasses to understand which methods are implemented where.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

Comment thread fme/core/step/radiation.py Outdated
self.module.load_state(state["module"])
self.radiation_module.load_state(state["radiation_module"])
if "corrector" in state:
self._load_corrector_state(state["corrector"])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Would you want to call _load_corrector_state({}) if the saved corrector had no state? For example to raise if somehow we had a has-state corrector and were trying to load a no-corrector-state checkpoint?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, udpated

self.call_count = 0
self.seen_states: list[CorrectorState | None] = []

def train(self, mode: bool = True) -> "_RecordingCorrector":

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Should there be tests exercising these new methods? I didn't notice earlier but I think an AI reviewer would call it out.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Comment thread fme/core/corrector/registry.py Outdated

@dataclasses.dataclass
class CorrectorConfigABC(abc.ABC):
corrector_disabled_epochs: int = dataclasses.field(default=0, kw_only=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chore: Remove this property from the ABC, right? Is it required for all CorrectorConfig?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

@mcgibbon mcgibbon left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, one possible nit/chore to look at.

Comment thread fme/core/step/step.py Outdated
"""Set the step (and all submodules) to evaluation mode."""
return self.train(False)

def _set_corrector_train_mode(self, mode: bool) -> None:

@mcgibbon mcgibbon Jun 12, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getattr is a red flag, and I'm realizing from it that I didn't notice this is on the ABC. The corrector is specific to the subclass Step, shouldn't this code go there? That means train should be abstract instead of final - you could make a case for that one making an exception and allowing a default, but if you do that, at least use the super-call version to do the base part instead of copy-pasting.

@elynnwu elynnwu merged commit 6046ad5 into feature/disable-corrector-first-epochs Jun 15, 2026
1 check passed
@elynnwu elynnwu deleted the feature/disable-corrector-top-level branch June 15, 2026 21:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants