fix: use forcing data_mask for ensemble training forward pass and loss #1262
yyexela wants to merge 5 commits into
When training with n_ensemble > 1, data_mask was read from input_ensemble_data (the IC slice, which carries no mask) instead of forcing_ensemble_data / data. This left NaN forcing values unsanitized before the forward pass, producing NaN loss for any masked ensemble batch.

Also removes the Stepper.set_epoch override that called request_latent_global_mean_envelope_reset; the base-class no-op in TrainStepperABC is sufficient now that the latent global-mean removal feature no longer needs per-epoch resets.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
These are still called by the trainer in main and are needed for the clip_latent_global_means feature. The removal belonged to the transformer-backbones branch where trainer.py drops the stepper.set_epoch call; it must not land here independently.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…spection Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… training Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Claude Description

Root cause
Both bugs are the same mistake: using the IC-derived mask where the full mask is needed. Change 1 —
Fixes a bug where masked forcing variables containing NaN could produce NaN loss during ensemble training.
When `train_on_batch` is called with `n_ensemble > 1`, the `data_mask` controlling which forcing values get zeroed before the forward pass was read from `input_ensemble_data` (the initial-condition slice, which carries no mask) instead of `forcing_ensemble_data`. The same incorrect source was used for loss accumulation. This meant any sample with a masked (NaN) forcing variable would propagate NaN through the model and produce NaN loss.

Changes:
- `TrainStepper._run_steps`: pass `data_mask=forcing_ensemble_data.data_mask` to `predict_generator` (was `input_ensemble_data.data_mask`)
- `TrainStepper._run_steps`: pass `data_mask=data.data_mask` to `_accumulate_step_loss` (was `input_batch_data.data_mask`)
- Tests added
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
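
The failure mode described above can be illustrated with a toy sketch. It is not the repository's code: `sanitize`, `toy_forward`, and `squared_error` are hypothetical stand-ins, and the convention that a mask entry of `True` means "valid" is an assumption made only for this example. It shows why reading the mask from a source that carries none (here modelled as `None`) leaves NaN forcing values in place and yields a NaN loss, while reading it from the forcing data does not.

```python
import math

def sanitize(values, mask):
    """Zero out masked entries. Assumed convention: True = valid, False = masked.

    A missing mask (None) leaves the values untouched, which is how the
    IC-derived mask behaves in this toy model.
    """
    if mask is None:
        return list(values)
    return [v if keep else 0.0 for v, keep in zip(values, mask)]

def toy_forward(forcing):
    """Hypothetical stand-in for the model: just sums the forcing values."""
    return sum(forcing)

def squared_error(prediction, target):
    return (prediction - target) ** 2

# One forcing variable is masked and therefore stored as NaN.
forcing = [1.0, float("nan"), 3.0]
forcing_mask = [True, False, True]   # mask carried by the forcing data
ic_mask = None                       # the initial-condition slice carries no mask
target = 5.0

# Before the fix: mask taken from the IC slice, so the NaN reaches the model.
buggy_loss = squared_error(toy_forward(sanitize(forcing, ic_mask)), target)

# After the fix: mask taken from the forcing data, so the NaN is zeroed first.
fixed_loss = squared_error(toy_forward(sanitize(forcing, forcing_mask)), target)

print(math.isnan(buggy_loss), fixed_loss)
```

In the toy model a single NaN input is enough to poison the loss, which matches the symptom reported for any masked ensemble batch.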