Add combined infill-prediction multi-task training#1224
Conversation
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Key changes based on review: - TrainingTask now has separate previous_step_input_names and current_step_input_names instead of a single list - Added forcing_names to distinguish input-only variables that should never be prediction targets - Disabled residual prediction (omitted from config) - Added multi-step rollout support: normal forward prediction for steps 0..N-2, task applied on the final step only - Added per-task loss scaling via TaskWeights dataclass - Extend existing TrainStepper rather than creating a new type, with clean branch in _accumulate_loss - Resolved open questions on corrector/ocean (apply if configured), data requirements (always use N timesteps) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- InfillPredictionStep handles missing variables the same way as SingleModuleStep: caller provides all variables, step uses data_mask - TaskSamplingConfig lives exclusively in TrainStepperConfig, not in the step config (training concern, not inference) - TaskSampler.sample() accepts data_mask to exclude variables that are missing from the entire batch, preventing degenerate tasks - Remove task_sampling from step config YAML and clean up notes Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Breaks the design into 3 independently reviewable commits: 1. Config dataclasses + TaskSampler 2. InfillPredictionStepConfig + InfillPredictionStep 3. TrainStepper integration with two-phase training loop Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Key changes: - Multi-task per update: prediction loss for all N steps + task loss from one additional forward pass (instead of replacing final step) - Per-sample task sampling: each sample independently gets a task type and variable assignments, respecting per-sample data_mask - SampledTasks with per-sample masks replaces single TrainingTask - Auto-encode outputs = exactly the selected inputs (not a subset) - Output channels = non-forcing variables only (forcing never predicted) - output_data_mask uses float loss_scale values for natural weighted averaging in existing loss infrastructure Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Specifies how input/output variables are sampled per task type: - Disjoint tasks (infill, infill_prediction): sample forcing inputs independently, then split the non-forcing "contested pool" symmetrically using total-then-split to avoid biasing n_in vs n_out - Non-disjoint tasks (prediction, combined_all): sample n_in and n_out independently since they don't compete - Auto-encode: single count, variables are both inputs and outputs - min_in_contested = max(0, min_input - n_forcing_in) allows all-forcing input cases without requiring it Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Make variable assignment explicit: uniform random subsets via random.sample (every variable in a pool has equal selection probability) - Expand test section into three categories: constraint validation (deterministic), config/error handling, and statistical tests for the variable selection algorithm (symmetry, uniformity, full coverage) - Address user feedback on test extensiveness and assignment mechanism Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Introduces TaskWeights, TaskSamplingConfig, InferenceSchemeConfig, SampledTasks, and TaskSampler with per-sample task assignment. The sampler supports five task types (auto-encode, infill, prediction, infill-prediction, combined-all) with symmetric variable selection for disjoint tasks and data_mask-aware variable filtering. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…g_names Adds InfillPredictionStepConfig (registered as "infill_prediction") and InfillPredictionStep which use all_names for training but inference_scheme for inference. The step fills missing variables with zeros and builds appropriate data_mask when receiving partial inputs. Also adds all_training_names property to StepConfigABC (default None) and updates StepperConfig.all_names to prefer it when available, so the data loader fetches all training variables. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds task_sampling field to TrainStepperConfig. When set, TrainStepper runs an additional forward pass after the prediction loop with task-sampled variable masking and accumulates the task loss alongside the prediction loss. Validates that task_sampling is only used with InfillPredictionStepConfig. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…l-prediction # Conflicts: # fme/ace/stepper/single_module.py
The test_variable_not_in_data_mask_treated_as_available assertion had `or True` which made it always pass. Replaced with a multi-sample check that verifies the variable actually appears. Moved task sampling imports from inline (bottom of test_single_module.py) to the module-level import block at the top. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1. _run_task_step now correctly handles n_ensemble > 1 by repeat-interleaving per-IC task masks and ground truth to match the ensemble-folded batch dimension, and unfolding the task output for ensemble-aware loss computation. 2. Build forcing_at_step from step_config.forcing_names instead of _input_only_names so forcing variables in all_names but not in inference_scheme.in_names are correctly sourced from data. 3. Remove infill-prediction-design.md and infill-prediction-implementation-plan.md from the repo. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Design document for this feature (removed from the repo in the final commit): |
Add TaskSamplingConfig, TaskWeights, InfillPredictionStepConfig, and InferenceSchemeConfig to fme.ace.__init__ so the nested-dataclass symbol test passes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Example configurationBelow is an example YAML snippet showing how to configure multi-task training. The new configuration lives in two places: the step type ( stepper:
step:
type: infill_prediction
config:
builder:
type: SphericalFourierNeuralOperatorNet
config:
scale_factor: 1
embed_dim: 256
num_layers: 8
# All variables the model can see during training.
# Drives data loading — every name here must be in the dataset.
all_names:
- air_temperature
- specific_humidity
- surface_pressure
- precipitation_rate # non-forcing, training-only (not in inference scheme)
- insolation # forcing
- land_sea_mask # forcing
# Forcing variables are input-only: never predicted by the model.
forcing_names:
- insolation
- land_sea_mask
normalization:
network:
means: {air_temperature: 250.0, specific_humidity: 0.005, ...}
stds: {air_temperature: 30.0, specific_humidity: 0.004, ...}
# Inference scheme defines the variable routing at inference time.
# Must be a subset of all_names.
inference_scheme:
in_names:
- air_temperature
- specific_humidity
- surface_pressure
- insolation
- land_sea_mask
out_names:
- air_temperature
- specific_humidity
- surface_pressure
next_step_forcing_names: []
prescribed_prognostic_names: []
# Must be True (the network receives 2x channels: values + mask indicators).
include_channel_mask_inputs: true
# Optional, same as SingleModuleStepConfig:
# ocean: ...
# corrector: ...
# secondary_decoder: ...
stepper_training:
loss:
type: MSE
task_sampling:
tasks:
# Each task has a probability (relative sampling weight, normalized to
# sum to 1) and loss_scaling (multiplier on the output mask).
# Set probability to 0 to disable a task.
auto_encode:
probability: 1.0
loss_scaling: 1.0
infill:
probability: 1.0
loss_scaling: 1.0
prediction:
probability: 1.0
loss_scaling: 1.0
infill_prediction:
probability: 1.0
loss_scaling: 1.0
combined_all:
probability: 1.0
loss_scaling: 1.0
# Minimum number of input/output variables selected per sample.
# Must be >= 1.
min_input_variables: 1
min_output_variables: 1What you can configure
What you can't configure
|
Move sampled task masks to model device before use, fixing RuntimeError when training with GPU. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
This file should probably be named something more specific to the contents of the file than the purpose of our experiments right now, like this seems more related to "tasks" than infill and prediction specifically.
There was a problem hiding this comment.
Renamed to task_step.py (and the test to test_task_step.py).
| auto_encode: float = 1.0 | ||
| infill: float = 1.0 | ||
| prediction: float = 1.0 | ||
| infill_prediction: float = 1.0 | ||
| combined_all: float = 1.0 | ||
|
|
||
| auto_encode_loss_scale: float = 1.0 | ||
| infill_loss_scale: float = 1.0 | ||
| prediction_loss_scale: float = 1.0 | ||
| infill_prediction_loss_scale: float = 1.0 | ||
| combined_all_loss_scale: float = 1.0 |
There was a problem hiding this comment.
Instead of having two properties for each task, can we have a sub-configuration for each task that contains these two properties? Something like "probability" and "loss_scaling"?
There was a problem hiding this comment.
Restructured to use a TaskConfig sub-dataclass with probability and loss_scaling fields per task.
| min_output_variables: Minimum number of output variables to select. | ||
| """ | ||
|
|
||
| task_weights: TaskWeights = dataclasses.field(default_factory=TaskWeights) |
There was a problem hiding this comment.
This has weights and probabilities, and controls what tasks we're using, so let's just call it "tasks:".
| """Defines how the model behaves at inference time. | ||
|
|
||
| This mirrors the variable routing of SingleModuleStepConfig for | ||
| standard forward prediction. Future schemes could add post-hoc |
There was a problem hiding this comment.
nit: Let's not discuss future plans in a docstring.
| """Run the additional task step after the prediction loop.""" | ||
| assert self._task_sampler is not None | ||
| assert self._task_loss_obj is not None | ||
| step_config = cast(InfillPredictionStepConfig, self._stepper._step_obj.config) |
There was a problem hiding this comment.
Is this an enforced assumption? if not, we should raise when it's False.
Ideally we should raise at configuration validation time, if this is required.
There was a problem hiding this comment.
The validation already happens at TrainStepper.__init__ time (raises ValueError if step config isn't InfillPredictionStepConfig). Replaced the cast with a stored typed attribute (self._task_step_config) set during init, so there's no runtime cast needed.
There was a problem hiding this comment.
Great. Is there also a __post_init__ configuration check? When we run fme.ace.validate_config, I want an exception raised if it's invalid.
There was a problem hiding this comment.
Pre-review agent: Addressed in 7494484. Added TaskSamplingConfig.validate_for_step(accepted_input_names, loss_names) which is called from TrainConfig.__post_init__, so fme.ace.validate_config will raise at config-validation time. Additionally, this commit generalizes task sampling to work with any step type by using accepted_input_names/loss_names from StepConfigABC instead of requiring InfillPredictionStepConfig.
| torch.testing.assert_close(output.data["a"][:, 0], ic_a) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- |
There was a problem hiding this comment.
nit: Remove comment about how this was committed.
- Rename infill_prediction.py to task_step.py for generality - Restructure TaskWeights: use per-task TaskConfig with probability and loss_scaling fields instead of flat attributes - Rename TaskSamplingConfig.task_weights to tasks - Remove future-plan language from InferenceSchemeConfig docstring - Store typed InfillPredictionStepConfig at init to avoid cast at runtime - Remove "Commit 3" comment from test file Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
I'm working on updating this so that it works seamlessly with all Step types so long as their input/output names are appropriate to the task. The InfillPredictionStep will still be needed for most tasks since we don't have other steps that accept output-step values as inputs, but we could use e.g. a task configuration using only prediction along with the existing SingleModuleStep. |
Add accepted_input_names and accepted_next_step_input_names properties to StepConfigABC so task sampling derives variable pools from generic step properties instead of requiring InfillPredictionStepConfig. Validate task feasibility at config time via TaskSamplingConfig.validate_for_step called from TrainConfig.__post_init__. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
This PR should add test_train.py-style tests that cover at least SingleModuleStep in prediction-only mode, and the infill step in a more complex task mode. It would be nice if they can run as fast as possible. |
There was a problem hiding this comment.
OK sorry my last comment was not fully right. The infill prediction step should exist in its own file, appropriately named. The task sampling logic probably doesn't belong in fme/core/step at all. It might even be best suited for fme/ace/stepper/task.py.
There was a problem hiding this comment.
Claude: Done. Moved task sampling logic (TaskConfig, TaskWeights, TaskSamplingConfig, TaskSampler, SampledTasks, TaskType) to fme/ace/stepper/task.py. InfillPredictionStep and InfillPredictionStepConfig stay in fme/core/step/infill_prediction.py. Tests moved to fme/ace/stepper/test_task.py.
| optimization=optimization, | ||
| ) | ||
| optimization.accumulate_loss(task_loss) | ||
| metrics["loss_task"] = task_loss.detach() |
There was a problem hiding this comment.
I think we don't want the metric name to change when we're doing task-based training. Use the existing value (is it batch_loss? or just loss at this level?).
There was a problem hiding this comment.
Claude: Removed the separate loss_task metric. The task loss is already accumulated into the optimizer before metrics["loss"] is recorded, so it's included in the total loss metric.
| step_config = stepper._step_obj.config | ||
| accepted = step_config.accepted_input_names | ||
| loss = step_config.loss_names | ||
| self._task_sampler = TaskSampler(config.task_sampling, accepted, loss) |
There was a problem hiding this comment.
Question: Doesn't the task sampler need to know about input names that are accepted at the first and at the second step, separately?
There was a problem hiding this comment.
Claude: The task sampler's "previous step" and "current step" masks both control what goes into args.input — they differ in data source (model output vs ground truth), not which step input slot they target. The next_step_input_data is built independently from forcing and is not task-sampled.
So accepted_input_names is the right single pool: it covers everything the step can receive as args.input, regardless of whether the data came from the previous step's output or current-step ground truth.
For InfillPredictionStep, accepted_input_names = all_names (broader than inference input_names). For SingleModuleStep, accepted_input_names = input_names (same as inference). In both cases, the same set applies to both previous-step and current-step sourcing.
| self._task_sampler = TaskSampler(config.task_sampling, accepted, loss) | ||
| task_loss_normalizer = stepper._step_obj.get_loss_normalizer() | ||
| non_forcing = [n for n in accepted if n in set(loss)] | ||
| self._task_loss_obj = config.loss.build( |
There was a problem hiding this comment.
Question: Why is this pattern different from above, self._loss_obj = self._stepper.build_loss(config.loss)?
There was a problem hiding this comment.
Claude: Good catch — they should use the same pattern. Refactored to add an out_names parameter to Stepper.build_loss(), so both losses are now built through the same method:
self._loss_obj = self._stepper.build_loss(config.loss)
# ...
self._task_loss_obj = self._stepper.build_loss(config.loss, out_names=non_forcing)The only difference is out_names: the regular loss uses loss_names (the default), while the task loss uses the non-forcing subset of accepted input names.
… tests - Move task sampling logic to fme/ace/stepper/task.py; keep InfillPredictionStep in fme/core/step/infill_prediction.py - Remove separate loss_task metric (task loss is already in total loss) - Use Stepper.build_loss for both normal and task losses - Add test_train.py integration tests for SingleModuleStep prediction-only and InfillPredictionStep mixed-task modes Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Claude: Added two integration tests in
Both run in ~1.5s each (3s total) using minimal data and a tiny SFNO model. |
|
One thing that I think would make the config easier to parse:
|
|
Otherwise I think this looks good, but it's a lot to add to main at once, might be good to run some tests on each task to see what is most useful |
Yeah, perhaps @Arcomano1234 can give you some multi-task training assignments. For now I'll work on another PR to merge just the non-task portions of this, to reduce the chance of merge conflicts. FWIW because we don't support training backwards compatibility, it's easier to remove or modify tasks after the fact if needed, compared to removing anything that's on the module config. |
|
I drafted a PR that attempts to split part of this out, but it doesn't seem to make much sense on its own. Closing it. |
…l-prediction # Conflicts: # fme/ace/stepper/single_module.py # fme/core/step/step.py
Adds a new
InfillPredictionStepthat trains on multiple tasks per batch (auto-encoding, infill, prediction, infill-prediction, combined-all) by randomly masking input/output variables. At inference time it behaves likeSingleModuleStep. Per gradient update, the normal prediction loss for all forward steps is computed, then one additional forward pass with task masking adds a task loss.Task sampling is generic and works with any step type — variable pools are derived from
StepConfigABC.accepted_input_namesandloss_names, with config-time validation that raises if enabled tasks are infeasible for the step's variable routing.Changes:
fme.core.step.task_step: NewTaskConfig,TaskWeights,TaskSamplingConfig,TaskSampler,SampledTasks,InferenceSchemeConfig,InfillPredictionStepConfig,InfillPredictionStepTaskSamplersamples per-batch-element task types and variable assignments respecting disjointness and minimum-variable constraintsInfillPredictionStepusesall_namesfor input channels with channel mask indicators andnon_forcing_namesfor output channels; fills missing variables with zeros at inferencefme.core.step.step.StepConfigABC: Addaccepted_input_names,accepted_next_step_input_names, andall_training_namesproperties so step configs can declare training-time variable capabilities distinct from inference-time routingfme.ace.stepper.single_module.StepperConfig.all_names: Checkall_training_namesfirst, fall back toset(input_names + output_names)fme.ace.stepper.single_module.TrainStepperConfig: Add optionaltask_sampling: TaskSamplingConfigfieldfme.ace.stepper.single_module.TrainStepper: Whentask_samplingis set, buildTaskSamplerfrom genericaccepted_input_names/loss_names; after the prediction loop, run one additional task step and accumulate its lossn_ensemble > 1by repeat-interleaving per-IC masks/ground-truth across ensemble membersinput_only_names(accepted inputs minus loss names)fme.ace.train.train_config.TrainConfig.__post_init__: Config-time validation viaTaskSamplingConfig.validate_for_stepfme.ace.stepper.loss_schedule.LossSchedule: Addis_trainingpropertyTests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated