Add corrector regularization to training #1218

Draft
mcgibbon wants to merge 6 commits into main from feature/corrector-regularization
@mcgibbon

@mcgibbon mcgibbon commented Jun 3, 2026

Contributor

Add an optional training-time penalty on the magnitude of corrector adjustments, enabling L1/L2 regularization that discourages the model from relying on the corrector. Supported for both single-module and coupled training; when configured, the specified loss is computed between each corrector's per-variable correction (post-corrector minus pre-corrector) and a zero baseline in the loss-normalized space, then accumulated into the training loss.

Changes:

  • fme.core.step.step.StepResult: new dataclass replacing the bare TensorDict return type of StepABC.step; carries the denormalized output plus optional per-variable corrections. All StepABC implementations and consumers updated.

  • fme.core.step.single_module.step_with_adjustments: records corrections = post − pre when a corrector runs; MultiCallStep forwards corrections from its wrapped step.

  • fme.core.loss.CorrectorRegularizationConfig: new LossConfig + weight wrapper, with validation rejecting EnsembleLoss/NaN/global_mean_type.

  • fme.ace.stepper.single_module.TrainStepperConfig.corrector_regularization: optional field; built via Stepper.build_corrector_regularization_loss and applied per optimized step in TrainStepper._accumulate_loss. Per-step (`corrector_regularization_step_{i}`) and epoch-aggregated (`corrector_regularization`) metrics.

  • fme.coupled.stepper.ComponentTrainingConfig.corrector_regularization: per-realm version. `CoupledStepperTrainLoss.compute_corrector_regularization` is wired into `CoupledTrainStepper._accumulate_step_loss`; emits `loss/{realm}_corrector_regularization_step_{i}`.

  • `fme.ace` re-exports `LossConfig` and `CorrectorRegularizationConfig` so nested-dataclass symbol checks see them.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

@mcgibbon mcgibbon changed the title Feature/corrector regularization Add corrector regularization to training Jun 3, 2026
mcgibbon and others added 4 commits June 3, 2026 15:45
Replace the bare TensorDict return type from StepABC.step (and the
free-standing step_with_adjustments) with a StepResult dataclass. The
new type currently carries only the denormalized output dict and has
no behavioral effect, but provides a structured place to surface
additional per-step information in subsequent commits (e.g. corrector
corrections used for regularization).

All six StepABC implementations now wrap their outputs in StepResult;
multi-call's StepMethod alias is updated and MultiCall.step extracts
.output internally; Stepper.step, predict_generator, and the coupled
predict generator are updated to thread the new type and unpack
.output where downstream code expects a plain dict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add an optional ``corrections`` field to ``StepResult``. When a
corrector is configured, ``step_with_adjustments`` now records the
per-variable post-corrector minus pre-corrector tensors in
denormalized space and returns them alongside the output;
``MultiCallStep`` forwards corrections from its wrapped step. No
consumer of these corrections yet — this commit exposes the data so
later commits can apply a training-time regularization that penalizes
the corrector's adjustments.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
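
For readers skimming the thread, the shape described in the two commit messages above can be sketched roughly as follows. This is an illustration only, not the code in this PR: `StepResultSketch`, `step_with_adjustments_sketch`, and `clamp_nonnegative` are hypothetical names, and plain floats stand in for tensors so the example runs without PyTorch.

```python
import dataclasses
from typing import Callable, Optional

# Plain floats stand in for tensors so the sketch runs without PyTorch.
Fields = dict[str, float]


@dataclasses.dataclass
class StepResultSketch:
    """Hypothetical stand-in for StepResult: output plus optional corrections."""

    output: Fields
    corrections: Optional[Fields] = None


def step_with_adjustments_sketch(
    raw: Fields,
    corrector: Optional[Callable[[Fields], Fields]] = None,
) -> StepResultSketch:
    """Apply an optional corrector and record corrections = post - pre."""
    if corrector is None:
        # No corrector configured: nothing to record.
        return StepResultSketch(output=raw)
    corrected = corrector(raw)
    corrections = {
        name: corrected[name] - raw[name] for name in corrected if name in raw
    }
    return StepResultSketch(output=corrected, corrections=corrections)


def clamp_nonnegative(fields: Fields) -> Fields:
    """Toy corrector (an assumption for illustration): clamp values at zero."""
    return {name: max(value, 0.0) for name, value in fields.items()}


result = step_with_adjustments_sketch(
    {"precip": -0.25, "temp": 280.0}, clamp_nonnegative
)
```

Consumers that only need the prediction would read `result.output`; a regularization term would read `result.corrections` and skip the penalty when it is `None`.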
Add ``CorrectorRegularizationConfig`` (a ``LossConfig`` plus a scalar
weight, with validation that rejects EnsembleLoss/NaN/global_mean_type
since the comparison is gen-vs-gen). When set on ``TrainStepperConfig``,
``TrainStepper`` builds a ``WeightedMappingLoss`` via the new
``Stepper.build_corrector_regularization_loss`` helper using the loss
normalizer; each optimized forward step adds
``weight * loss(corrections, zeros).total()`` to the accumulated
training loss. Per-step and epoch-aggregated metrics are recorded.

Because the loss normalizer is applied to both predict and target
inside ``WeightedMappingLoss``, the normalizer's mean offset cancels
naturally — corrections in denormalized space and a zero target gives
the correct ``corrections / std`` magnitude.

``LossConfig`` and ``CorrectorRegularizationConfig`` are exported from
``fme.ace`` so nested-dataclass symbol checks see them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
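
The cancellation of the normalizer's mean described above can be checked numerically with a small sketch. Everything here is an assumption made for illustration: the function names, the variable names and statistics, and the mean-over-variables reduction are hypothetical, and the actual reduction inside `WeightedMappingLoss` may differ.

```python
def normalize(value: float, mean: float, std: float) -> float:
    """Standard (value - mean) / std normalization."""
    return (value - mean) / std


def corrector_penalty(
    corrections: dict[str, float],
    means: dict[str, float],
    stds: dict[str, float],
    weight: float,
    power: int,
) -> float:
    """Weighted penalty between normalized corrections and a normalized zero.

    power=1 gives an L1-style penalty, power=2 an L2-style one. The mean
    over variables is an assumed reduction, not necessarily the one used.
    """
    total = 0.0
    for name, correction in corrections.items():
        predicted = normalize(correction, means[name], stds[name])
        target = normalize(0.0, means[name], stds[name])
        # The mean offset appears in both terms, so the difference is
        # correction / std regardless of the mean.
        total += abs(predicted - target) ** power
    return weight * total / len(corrections)


corrections = {"thetao_0": 0.5, "sst": -1.0}
means = {"thetao_0": 10.0, "sst": 290.0}
stds = {"thetao_0": 2.0, "sst": 4.0}

l2_term = corrector_penalty(corrections, means, stds, weight=1.0, power=2)
l1_term = corrector_penalty(corrections, means, stds, weight=0.2, power=1)
```

Changing the entries of `means` leaves both terms unchanged, which is the point of the argument: only `corrections / std` contributes.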
Plumb corrector corrections through the coupled stepper so that per-realm
regularization can be applied during training. ``ComponentStepPrediction``
and ``ComponentEnsembleStepPrediction`` carry an optional corrections
dict; the coupled predict generator forwards each component's corrections
from its underlying ``StepResult``.

Add ``corrector_regularization`` to ``ComponentTrainingConfig`` so the
ocean and atmosphere components can each enable the penalty independently.
``CoupledStepperTrainLoss`` builds a per-realm
``WeightedMappingLoss`` (using each component's loss normalizer) and
exposes ``compute_corrector_regularization`` for use inside the
``CoupledTrainStepper._accumulate_step_loss`` loop, where the term is
accumulated into the optimizer and recorded as
``loss/{realm}_corrector_regularization_step_{i}``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@mcgibbon mcgibbon force-pushed the feature/corrector-regularization branch from f312676 to a6dea01 Compare June 3, 2026 15:50
@jpdunc23

jpdunc23 commented Jun 3, 2026

Member

What loss / weight do you suggest I try with this feature? Thinking about launching two experiments, one with an L1 penalty and another with an L2 penalty, but lmk if you have a particular config in mind.

@mcgibbon

mcgibbon commented Jun 3, 2026

Contributor Author

> What loss / weight do you suggest I try with this feature? Thinking about launching two experiments, one with an L1 penalty and another with an L2 penalty, but lmk if you have a particular config in mind.

Yeah that matches what I was thinking.

The normalization is the same as for the main loss term, which helps when reasoning about weighting. L2 is probably the better match for the behavior you want, and it also lets the model keep slightly negative near-zero values with little penalty, so I'd focus on that if you want to do a weight sweep. 1.0 is a reasonable starting point for the weight, which means the model gives equal importance to obeying the corrector's goals and to having skill. Even a smaller value like 0.05-0.2 should significantly reduce the divergence between pre- and post-corrector values, because you'll only get divergence to the extent that it helps the model predict better. If you got large divergence even with a weight of 0.2, it would imply the divergence is helping the model's skill fairly significantly, which I don't expect.

Maybe try 0.2, 1.0, and 5.0 on L2 as a first pass if you want to sweep, or 1.0 on L2 if you're limiting to one config?
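
One way to read the weighting argument, assuming the per-step form `weight * loss(corrections, zeros).total()` from the commit message, is as the objective below. The symbols are notation introduced here rather than names from the code: $\mathcal{L}_{\text{skill}}^{(i)}$ is the main loss at optimized step $i$, $\hat{c}^{(i)}$ is the correction in loss-normalized space, $\ell$ is the configured L1 or L2 loss, and $w$ is the regularization weight.

```latex
\mathcal{L}_{\text{train}}
  = \sum_{i} \left[ \mathcal{L}_{\text{skill}}^{(i)}
  + w \, \ell\!\left(\hat{c}^{(i)},\, 0\right) \right]
```

Because both terms share the same normalization, $w = 1$ puts them on a comparable scale, and the suggested sweep corresponds to $w \in \{0.2,\ 1.0,\ 5.0\}$.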

@jpdunc23

jpdunc23 commented Jun 3, 2026

Member

@mcgibbon please merge #1223 when you can (fixed on my exper branch after failing job)

@jpdunc23

jpdunc23 commented Jun 3, 2026

Member

#1223)

step_total_loss and reg_loss share the same forward graph. With gradient
accumulation, accumulate_loss backward()s immediately, so two separate
calls freed and re-backwarded the graph -> RuntimeError. Combine into
one accumulate_loss(step_total_loss + reg_loss) call.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
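
The failure mode in the commit message above can be reproduced in isolation with a minimal PyTorch sketch. The `accumulate_loss` function here is a hypothetical stand-in that only mimics the "backward immediately" behavior described; the tensors and loss expressions are made up for illustration.

```python
import torch


def accumulate_loss(loss: torch.Tensor) -> None:
    """Hypothetical stand-in: call backward right away, as the commit
    message describes for gradient accumulation."""
    loss.backward()


def run(combined: bool) -> str:
    w = torch.tensor([1.0, 2.0], requires_grad=True)
    # Shared forward graph: the multiply saves its inputs for backward.
    hidden = (w * w).sum()
    step_total_loss = hidden * 2.0
    reg_loss = hidden * 0.5
    try:
        if combined:
            # One backward pass over the shared graph.
            accumulate_loss(step_total_loss + reg_loss)
        else:
            # The first backward frees the shared graph's saved tensors,
            # so the second backward through the same graph fails.
            accumulate_loss(step_total_loss)
            accumulate_loss(reg_loss)
    except RuntimeError:
        return "RuntimeError"
    return "ok"


separate_status = run(combined=False)
combined_status = run(combined=True)
```

Summing the two losses before a single backward call yields the same gradients as backpropagating each separately would, without needing `retain_graph=True`.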
@mcgibbon

mcgibbon commented Jun 4, 2026

Contributor Author

Experiment here: https://wandb.ai/ai2cm/ace-samudra-cm4/runs/xzmbufek

Looks like so far it's improving the thetao evolution as much as the pre-corrector optimization. It remains to be seen whether this continues, or whether the regularization approach bottoms out earlier in how much improvement it gives.

@jpdunc23

Member

Could you please add a feature to exclude the penalty for specific variables? Excluding sea ice in pre-corrector optim removed the small positive sea ice biases in mid latitudes, so I think it should do the same here.

Here's an example from my PR of configuring this using the NameAndPrefixMatcher class:

@dataclasses.dataclass
class PreCorrectorOptimizationConfig:
    """
    Configuration enabling pre-corrector optimization of the training loss.

    When this config is present, the training loss for corrector-modified
    variables is computed against the model's pre-correction (uncorrected)
    output rather than its corrected output, except for variables matched by
    ``exclude_names_and_prefixes``, which continue to be optimized against their
    corrected (post-adjustment) values. The presence of this object is itself
    the on/off switch; there is no separate ``enabled`` flag.

    The rollout state and returned predictions always use the fully corrected
    outputs regardless of this config; only the training loss target is affected.

    Parameters:
        exclude_names_and_prefixes: Names and prefixes of variables to exclude
            from pre-corrector optimization. Matching follows the same
            name-and-prefix convention as spatial masking: a bare name matches
            the 2D variable and all of its 3D levels, a trailing-underscore
            prefix (e.g. ``thetao_``) matches all levels, and an explicit
            ``name_<level>`` matches exactly.
    """

    exclude_names_and_prefixes: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        self._matcher = NameAndPrefixMatcher(self.exclude_names_and_prefixes)

    def select_precorrected(self, uncorrected: TensorMapping) -> TensorDict:
        """Return the subset of ``uncorrected`` to override into the loss target.

        Excluded variables are dropped so that they retain their corrected value
        in the loss; all other corrector-modified variables are returned so they
        override their corrected value with their pre-correction value.
        """
        return {
            name: value
            for name, value in uncorrected.items()
            if not self._matcher.matches(name)
        }

    def unmatched_exclusions(self, names: Iterable[str]) -> list[str]:
        """Return exclusion entries that match none of ``names``.

        Used for warn-once validation against the actual set of
        corrector-modified variables, which is only known at runtime.
        """
        names = list(names)
        return [
            entry
            for entry in self.exclude_names_and_prefixes
            if not any(NameAndPrefixMatcher([entry]).matches(name) for name in names)
        ]
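
Applied to this PR, the same idea might look like the sketch below: drop excluded variables from the corrections before the penalty is computed. This is a guess at the shape of the requested feature, not an implementation. `NameMatcherSketch` is a hypothetical stand-in written from the matching convention in the docstring above (the real `NameAndPrefixMatcher` may behave differently), and the variable names are made up.

```python
class NameMatcherSketch:
    """Hypothetical stand-in following the docstring's convention:
    a bare name matches the 2D variable and its numbered levels, a
    trailing-underscore prefix matches all levels, and an explicit
    ``name_<level>`` matches exactly."""

    def __init__(self, entries: list[str]):
        self._entries = list(entries)

    def matches(self, name: str) -> bool:
        return any(self._entry_matches(entry, name) for entry in self._entries)

    @staticmethod
    def _entry_matches(entry: str, name: str) -> bool:
        if name == entry:
            return True
        prefix = entry if entry.endswith("_") else entry + "_"
        # A level is assumed to be a purely numeric suffix, e.g. thetao_3.
        return name.startswith(prefix) and name[len(prefix):].isdigit()


def select_regularized(
    corrections: dict[str, float], exclude: list[str]
) -> dict[str, float]:
    """Keep only the corrections that should still be penalized."""
    matcher = NameMatcherSketch(exclude)
    return {
        name: value
        for name, value in corrections.items()
        if not matcher.matches(name)
    }


corrections = {
    "thetao_0": 0.1,
    "thetao_1": 0.2,
    "sea_ice_fraction": 0.3,
    "sst": 0.4,
}
kept = select_regularized(corrections, ["sea_ice_fraction", "thetao_"])
```

Filtering before the loss call keeps the excluded variables free to be adjusted by the corrector while every other variable is still pulled toward a zero correction.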
