fix: use forcing data_mask for ensemble training forward pass and loss #1262

Open
yyexela wants to merge 5 commits into main from fix/ensemble-data-mask


@yyexela yyexela commented Jun 11, 2026

Contributor

Fixes a bug where masked forcing variables containing NaN could produce NaN loss during ensemble training.

When train_on_batch is called with n_ensemble > 1, the data_mask controlling which forcing values get zeroed before the forward pass was read from input_ensemble_data (the initial-condition slice, which carries no mask) instead of forcing_ensemble_data. The same incorrect source was used for loss accumulation. This meant any sample with a masked (NaN) forcing variable would propagate NaN through the model and produce NaN loss.

Changes:

  • TrainStepper._run_steps: pass data_mask=forcing_ensemble_data.data_mask to predict_generator (was input_ensemble_data.data_mask)

  • TrainStepper._run_steps: pass data_mask=data.data_mask to _accumulate_step_loss (was input_batch_data.data_mask)

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

yyexela and others added 5 commits June 11, 2026 15:33
When training with n_ensemble > 1, data_mask was read from
input_ensemble_data (the IC slice, which carries no mask) instead of
forcing_ensemble_data / data. This left NaN forcing values unsanitized
before the forward pass, producing NaN loss for any masked ensemble
batch.

Also removes the Stepper.set_epoch override that called
request_latent_global_mean_envelope_reset; the base-class no-op in
TrainStepperABC is sufficient now that the latent global-mean removal
feature no longer needs per-epoch resets.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
These are still called by the trainer in main and are needed for the
clip_latent_global_means feature. The removal belonged to the
transformer-backbones branch where trainer.py drops the stepper.set_epoch
call; it must not land here independently.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…spection

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… training

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@yyexela yyexela requested a review from mcgibbon June 11, 2026 22:57
yyexela commented Jun 11, 2026

Contributor Author

Claude Description

Root cause

get_start calls subset_names(prognostic_names), which strips data_mask down to only prognostic variable keys. So:

  • input_batch_data.data_mask / input_ensemble_data.data_mask: prognostic vars only
  • data.data_mask / forcing_ensemble_data.data_mask: all vars (prognostic + forcing + output-only)

Both bugs are the same mistake: using the IC-derived mask where the full mask is needed.
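
To make the root cause concrete, here is a minimal sketch of how subsetting a mask by name drops the entries for forcing and output-only variables. The `subset_names` function below is a hypothetical stand-in written for this illustration, not the repository's implementation; the variable names `a` and `c` and the convention that `True` marks a valid sample are also assumptions.

```python
def subset_names(mask, names):
    """Hypothetical stand-in: keep only the mask entries whose key is in `names`."""
    return {name: mask[name] for name in names if name in mask}


# Assumed convention: True marks a valid sample, False marks a masked one.
full_mask = {
    "a": [True, True],   # prognostic variable
    "b": [True, False],  # forcing-only variable
    "c": [False, True],  # output-only variable
}
prognostic_names = ["a"]

# The IC-derived mask only knows about prognostic variables.
ic_mask = subset_names(full_mask, prognostic_names)

# Variables whose mask information is lost by the subsetting.
missing = sorted(set(full_mask) - set(ic_mask))
print(missing)
```

Any consumer that looks up `b` or `c` in `ic_mask` finds nothing, which is the situation both changes in this PR address.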


Change 1 — predict_generator: zero out NaN forcing before the forward pass

predict_generator uses data_mask to zero out masked forcing values before they enter the model. With input_ensemble_data.data_mask, forcing-only variables like b have no mask entry (stripped by subset_names), so their NaN values are never zeroed and propagate through to produce NaN loss.

Fix: use forcing_ensemble_data.data_mask, which comes from data directly with no subsetting.
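
The effect of the mask source on the forward pass can be sketched with a toy example. Everything here is illustrative: `zero_masked` and `toy_forward` are hypothetical helpers, not the real `predict_generator` logic, and the sketch assumes that a variable with no mask entry is passed through unchanged.

```python
import math


def zero_masked(values, mask):
    """Replace masked samples with 0.0; variables without a mask entry pass through."""
    out = {}
    for name, vals in values.items():
        valid = mask.get(name)
        if valid is None:
            out[name] = list(vals)  # no mask entry: nothing is sanitized
        else:
            out[name] = [v if ok else 0.0 for v, ok in zip(vals, valid)]
    return out


def toy_forward(inputs):
    """Stand-in for a model: per-sample sum over all input variables."""
    n_samples = len(next(iter(inputs.values())))
    return [sum(inputs[name][i] for name in inputs) for i in range(n_samples)]


forcing = {"a": [1.0, 2.0], "b": [0.5, math.nan]}  # second sample of b is masked
full_mask = {"a": [True, True], "b": [True, False]}
ic_mask = {"a": [True, True]}  # prognostic-only mask, no entry for b

bad = toy_forward(zero_masked(forcing, ic_mask))
good = toy_forward(zero_masked(forcing, full_mask))
print(bad, good)
```

With the prognostic-only mask the NaN in `b` reaches the toy model and contaminates the second sample; with the full mask it is zeroed first and the output stays finite.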


Change 2 — _accumulate_step_loss: exclude masked samples from loss count

_accumulate_step_loss uses data_mask to exclude masked samples from the loss denominator. With input_batch_data.data_mask, output-only variables have no mask entry, so all their samples are counted — even ones that should be excluded.

Fix: use data.data_mask, which has mask entries for all variables.
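
The denominator issue can be illustrated the same way. This is a toy masked mean-squared-error, not the actual `_accumulate_step_loss` code; the function name, the fallback of treating every sample as valid when a variable has no mask entry, and the sample values are all assumptions made for the sketch.

```python
def masked_mse(pred, target, mask, name):
    """Mean squared error over valid samples; returns (loss, number of samples counted)."""
    # Assumed fallback: with no mask entry, every sample is treated as valid.
    valid = mask.get(name, [True] * len(pred))
    total = 0.0
    count = 0
    for p, t, ok in zip(pred, target, valid):
        if ok:
            total += (p - t) ** 2
            count += 1
    loss = total / count if count else 0.0
    return loss, count


pred = [2.0, 1.0]
target = [0.0, 1.0]  # first sample of output-only variable c is a masked placeholder
full_mask = {"a": [True, True], "c": [False, True]}
ic_mask = {"a": [True, True]}  # no entry for c

with_ic_mask = masked_mse(pred, target, ic_mask, "c")
with_full_mask = masked_mse(pred, target, full_mask, "c")
print(with_ic_mask, with_full_mask)
```

With the prognostic-only mask both samples are counted, so the masked placeholder contributes to the loss; with the full mask only the valid sample is counted.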
