Add corrector regularization to training #1218
Replace the bare TensorDict return type from StepABC.step (and the free-standing step_with_adjustments) with a StepResult dataclass. The new type currently carries only the denormalized output dict and has no behavioral effect, but provides a structured place to surface additional per-step information in subsequent commits (e.g. corrector corrections used for regularization). All six StepABC implementations now wrap their outputs in StepResult; multi-call's StepMethod alias is updated and MultiCall.step extracts .output internally; Stepper.step, predict_generator, and the coupled predict generator are updated to thread the new type and unpack .output where downstream code expects a plain dict. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
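A minimal sketch of the wrapper this commit describes, not the code from the PR. Only the `StepResult` name and its `output` field come from the commit message; `toy_step` is hypothetical, and plain floats stand in for tensors so the sketch runs without torch.

```python
from dataclasses import dataclass
from typing import Dict

# Stand-in for the real TensorDict alias (assumption: a mapping from
# variable name to tensor). Floats replace tensors in this sketch.
TensorDict = Dict[str, float]


@dataclass
class StepResult:
    """Structured return value of a single step (sketch)."""

    output: TensorDict  # denormalized output


def toy_step(state: TensorDict) -> StepResult:
    # Hypothetical step: wrap the output dict instead of returning it bare.
    return StepResult(output={name: value + 1.0 for name, value in state.items()})


result = toy_step({"thetao": 10.0})
# Downstream code that expects a plain dict unpacks .output explicitly.
plain = result.output
```

Returning a dataclass rather than a bare dict means later fields can be added without changing every call site again.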
Add an optional ``corrections`` field to ``StepResult``. When a corrector is configured, ``step_with_adjustments`` now records the per-variable post-corrector minus pre-corrector tensors in denormalized space and returns them alongside the output; ``MultiCallStep`` forwards corrections from its wrapped step. No consumer of these corrections yet — this commit exposes the data so later commits can apply a training-time regularization that penalizes the corrector's adjustments. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
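The bookkeeping described here can be sketched as follows, assuming the correction is simply the post-corrector value minus the pre-corrector value per variable. `step_with_corrector_sketch`, `clip_negative`, and the variable name `sea_ice_fraction` are hypothetical, and floats again stand in for tensors.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

TensorDict = Dict[str, float]  # stand-in: floats instead of tensors


@dataclass
class StepResult:
    output: TensorDict
    corrections: Optional[TensorDict] = None  # None when no corrector ran


def step_with_corrector_sketch(
    raw_output: TensorDict,
    corrector: Optional[Callable[[TensorDict], TensorDict]] = None,
) -> StepResult:
    if corrector is None:
        return StepResult(output=raw_output)
    corrected = corrector(raw_output)
    # Per-variable post-corrector minus pre-corrector, in denormalized space.
    corrections = {name: corrected[name] - raw_output[name] for name in corrected}
    return StepResult(output=corrected, corrections=corrections)


def clip_negative(data: TensorDict) -> TensorDict:
    # Hypothetical corrector: force values to be non-negative.
    return {name: max(value, 0.0) for name, value in data.items()}


result = step_with_corrector_sketch(
    {"sea_ice_fraction": -0.25, "thetao": 4.0}, clip_negative
)
no_corrector = step_with_corrector_sketch({"thetao": 4.0})
```

Variables the corrector leaves untouched get a correction of exactly zero, so they contribute nothing to a penalty computed against a zero target.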
Add ``CorrectorRegularizationConfig`` (a ``LossConfig`` plus a scalar weight, with validation that rejects EnsembleLoss/NaN/global_mean_type since the comparison is gen-vs-gen). When set on ``TrainStepperConfig``, ``TrainStepper`` builds a ``WeightedMappingLoss`` via the new ``Stepper.build_corrector_regularization_loss`` helper using the loss normalizer; each optimized forward step adds ``weight * loss(corrections, zeros).total()`` to the accumulated training loss. Per-step and epoch-aggregated metrics are recorded. Because the loss normalizer is applied to both predict and target inside ``WeightedMappingLoss``, the normalizer's mean offset cancels naturally: corrections in denormalized space and a zero target give the correct ``corrections / std`` magnitude. ``LossConfig`` and ``CorrectorRegularizationConfig`` are exported from ``fme.ace`` so nested-dataclass symbol checks see them. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
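The mean-offset cancellation can be checked with a small worked example. This is a sketch under stated assumptions: the normalizer is taken to be `(x - mean) / std`, the loss is a plain mean squared error, and the numbers for `mean`, `std`, `weight`, and `corrections` are made up for illustration.

```python
def normalize(value: float, mean: float, std: float) -> float:
    # Assumed normalizer form: subtract the mean, divide by the std.
    return (value - mean) / std


def mse(pred, target) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


mean, std = 280.0, 4.0          # hypothetical normalization statistics
corrections = [2.0, -1.0, 0.0]  # post-corrector minus pre-corrector
zeros = [0.0] * len(corrections)

# Normalize both sides, as the loss does internally for predict and target.
pred_norm = [normalize(c, mean, std) for c in corrections]
target_norm = [normalize(z, mean, std) for z in zeros]
penalty = mse(pred_norm, target_norm)

# The mean cancels in the difference, leaving only corrections / std.
expected = mse([c / std for c in corrections], zeros)

weight = 0.5  # hypothetical regularization weight
reg_term = weight * penalty
```

Since `(c - mean) / std - (0 - mean) / std = c / std`, `penalty` and `expected` agree no matter what value `mean` takes.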
Plumb corrector corrections through the coupled stepper so that per-realm
regularization can be applied during training. ``ComponentStepPrediction``
and ``ComponentEnsembleStepPrediction`` carry an optional corrections
dict; the coupled predict generator forwards each component's corrections
from its underlying ``StepResult``.
Add ``corrector_regularization`` to ``ComponentTrainingConfig`` so the
ocean and atmosphere components can each enable the penalty independently.
``CoupledStepperTrainLoss`` builds a per-realm
``WeightedMappingLoss`` (using each component's loss normalizer) and
exposes ``compute_corrector_regularization`` for use inside the
``CoupledTrainStepper._accumulate_step_loss`` loop, where the term is
accumulated into the optimizer and recorded as
``loss/{realm}_corrector_regularization_step_{i}``.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
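A rough sketch of the per-realm accumulation, assuming an L2 penalty on `correction / std` and assuming the logged metric is the weighted term (the commit message does not say whether the weight is included). `realm_regularization_sketch`, `l2_penalty`, and all numbers are hypothetical; only the metric name pattern is taken from the commit message.

```python
from typing import Dict, Optional, Tuple


def l2_penalty(corrections: Dict[str, float], std: Dict[str, float]) -> float:
    # Mean over variables of the squared normalized correction.
    terms = [(corrections[name] / std[name]) ** 2 for name in corrections]
    return sum(terms) / len(terms)


def realm_regularization_sketch(
    realm_corrections: Dict[str, Optional[Dict[str, float]]],
    realm_weights: Dict[str, Optional[float]],
    realm_std: Dict[str, Dict[str, float]],
    step: int,
) -> Tuple[float, Dict[str, float]]:
    total = 0.0
    metrics: Dict[str, float] = {}
    for realm, corrections in realm_corrections.items():
        weight = realm_weights.get(realm)
        if weight is None or corrections is None:
            continue  # penalty not enabled for this realm, or no corrector ran
        term = weight * l2_penalty(corrections, realm_std[realm])
        total += term
        metrics[f"loss/{realm}_corrector_regularization_step_{step}"] = term
    return total, metrics


total, metrics = realm_regularization_sketch(
    realm_corrections={"ocean": {"thetao": 0.5}, "atmosphere": {"PRESsfc": 3.0}},
    realm_weights={"ocean": 0.25, "atmosphere": None},  # ocean only
    realm_std={"ocean": {"thetao": 0.5}, "atmosphere": {"PRESsfc": 100.0}},
    step=0,
)
```

Keeping a separate weight per realm lets one component enable the penalty while the other leaves it off.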
What loss / weight do you suggest I try with this feature? Thinking about launching two experiments, one with an L1 penalty and another with an L2 penalty, but lmk if you have a particular config in mind.
Yeah that matches what I was thinking. The normalization is the same as for the main loss term, which helps reason about weighting. L2 is probably the better match for the behavior you want, and also allows the model to keep slightly-negative zero-values with little penalty, so I'd focus on that if you want to do a weight sweep. 1.0 is a reasonable starting point for weighting, which means the model puts equal importance on obeying the corrector's goals and on having skill. Even a smaller value like 0.05-0.2 should significantly reduce the divergence between pre- and post-corrector values, because you'll only get divergence to the extent that this helps the model predict better. If you got large divergence even with 0.2 weight, it would imply the divergence is helping the model's skill fairly significantly, which I don't expect. Maybe try 0.2, 1.0, and 5.0 on L2 as a first pass if you want to sweep, or 1.0 on L2 if you're limiting to one config?
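To illustrate the L1 versus L2 point, here is a small sketch comparing the two penalties on a tiny and a large correction in normalized units. The correction sizes are made-up illustrative values; only the sweep weights 0.2, 1.0, and 5.0 come from the comment above.

```python
def l1(correction: float) -> float:
    return abs(correction)


def l2(correction: float) -> float:
    return correction * correction


# Hypothetical corrections in normalized (divided-by-std) units.
small = -0.0625  # e.g. a slightly negative value clipped back to zero
large = 2.0      # a correction the model leans on heavily

small_ratio = l2(small) / l1(small)  # L2 is much cheaper for tiny corrections
large_ratio = l2(large) / l1(large)  # L2 is more expensive for large ones

# Weighted L2 penalty for each weight in the suggested sweep.
sweep = {weight: weight * l2(large) for weight in (0.2, 1.0, 5.0)}
```

Below a normalized magnitude of 1 the L2 penalty is smaller than the L1 penalty, and above 1 it is larger, which is why L2 tolerates small clean-up corrections while still discouraging large ones.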
@mcgibbon please merge #1223 when you can (fixed on my exper branch after failing job)
Experiment here: https://wandb.ai/ai2cm/ace-samudra-cm4/runs/xzmbufek
#1223) step_total_loss and reg_loss share the same forward graph. With gradient accumulation, accumulate_loss calls backward() immediately, so two separate calls freed the graph on the first backward pass and then tried to backpropagate through it a second time, raising a RuntimeError. Combine them into one accumulate_loss(step_total_loss + reg_loss) call. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
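The failure mode can be reproduced in a few lines of PyTorch. This is a sketch, not the repository's `accumulate_loss`: plain `backward()` calls stand in for it, and the tensor `w` and helper `build_losses` are hypothetical.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)


def build_losses():
    # Both losses depend on the same intermediate, so they share a graph.
    hidden = w * w
    step_total_loss = hidden.sum()
    reg_loss = (hidden * hidden).sum()
    return step_total_loss, reg_loss


# Two separate backward calls: the first frees the shared graph's buffers.
step_total_loss, reg_loss = build_losses()
step_total_loss.backward()
try:
    reg_loss.backward()
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# One backward call on the summed loss traverses the shared graph once.
w.grad = None
step_total_loss, reg_loss = build_losses()
(step_total_loss + reg_loss).backward()
combined_grad = w.grad.clone()
```

Passing `retain_graph=True` to the first call would also avoid the error, but summing the losses keeps a single backward pass and does not hold the graph in memory any longer than needed.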
Looks like so far it's improving the thetao evolution as much as the precorrector optimization. To be seen if this continues, or if the regularization approach bottoms out earlier on how much improvement it gives.
Could you please add a feature to exclude the penalty for specific variables? Excluding sea ice in pre-corrector optim removed the small positive sea ice biases in mid latitudes, so I think it should do the same here. Here's an example from my PR of configuring this using the ace/fme/ace/stepper/single_module.py Lines 1462 to 1515 in 046a3f8 |

Add an optional training-time penalty on the magnitude of corrector adjustments, enabling L1/L2 regularization that discourages the model from relying on the corrector. Supported for both single-module and coupled training; when configured, the specified loss is computed between each corrector's per-variable correction (post-corrector minus pre-corrector) and a zero baseline in the loss-normalized space, then accumulated into the training loss.
Changes:
- `fme.core.step.step.StepResult`: new dataclass replacing the bare `TensorDict` return type of `StepABC.step`; carries the denormalized output plus optional per-variable corrections. All `StepABC` implementations and consumers updated.
- `fme.core.step.single_module.step_with_adjustments`: records `corrections = post − pre` when a corrector runs; `MultiCallStep` forwards corrections from its wrapped step.
- `fme.core.loss.CorrectorRegularizationConfig`: new `LossConfig` + `weight` wrapper, with validation rejecting `EnsembleLoss`/`NaN`/`global_mean_type`.
- `fme.ace.stepper.single_module.TrainStepperConfig.corrector_regularization`: optional field; built via `Stepper.build_corrector_regularization_loss` and applied per optimized step in `TrainStepper._accumulate_loss`. Per-step (`corrector_regularization_step_{i}`) and epoch-aggregated (`corrector_regularization`) metrics.
- `fme.coupled.stepper.ComponentTrainingConfig.corrector_regularization`: per-realm version. `CoupledStepperTrainLoss.compute_corrector_regularization` wired into `CoupledTrainStepper._accumulate_step_loss`; emits `loss/{realm}_corrector_regularization_step_{i}`.
- `fme.ace` re-exports `LossConfig` and `CorrectorRegularizationConfig` so nested-dataclass symbol checks see them.
Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated