[WIP] Feature/var masking dropout train stepper #1198

Open
yyexela wants to merge 6 commits into main from feature/var_masking_dropout_train_stepper

yyexela commented May 27, 2026

Adds training-time input variable dropout to SingleModuleStepConfig so models can be trained to be robust to missing input variables. VariableMaskingConfig now supports two modes: per-variable rates (with optional prefix groups for leveled fields) and uniform random selection of a number of variables to mask per sample. The effective mask is merged with any existing data_mask, so both sources of absence consistently zero the network input, the channel mask indicators, and the global-mean-removal extra channels.

Configuration options

VariableMaskingConfig

Controls input dropout with two optional sections:

  • per_variable (dict[str, float], default {}): Per-variable mask probabilities in [0, 1]. A rate of 1.0 always masks. A rate of 0.0 never masks the variable but still counts it as handled, which excludes it from the uniform pool. Keys ending with _ are prefix keys: they match all variables of the form {prefix}{digits} (for example, "air_temperature_" matches air_temperature_0, air_temperature_1, etc.), and all matched levels share a single random draw per sample.
  • per_variable.default_rate (float, default 0.0): Reserved key. Its rate is applied independently to each variable not matched by an exact or prefix key.
  • uniform (UniformMaskingConfig | None, default None): Optional uniform count-based masking config with min_vars, max_vars, and ignore_vars.
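The prefix-key rule can be illustrated with a small sketch. This is not the PR's implementation: resolve_groups is a hypothetical helper, and air_temperature_tendency is an invented variable name included only to show that a non-digit suffix is not matched.

```python
import re


def resolve_groups(per_variable, variable_names):
    """Map each configured key to (rate, matched variables).

    Hypothetical helper for illustration; not the PR's implementation.
    """
    groups = {}
    for key, rate in per_variable.items():
        if key == "default_rate":  # reserved key, not a variable name
            continue
        if key.endswith("_"):
            # Prefix key: matches {prefix}{digits} only.
            pattern = re.compile(re.escape(key) + r"\d+")
            members = [v for v in variable_names if pattern.fullmatch(v)]
        else:
            # Exact key: matches a single variable name.
            members = [v for v in variable_names if v == key]
        groups[key] = (rate, members)
    return groups


names = [
    "air_temperature_0",
    "air_temperature_1",
    "air_temperature_tendency",  # invented name; suffix is not digits
    "global_mean_co2",
]
groups = resolve_groups(
    {"global_mean_co2": 0.8, "air_temperature_": 0.2, "default_rate": 0.0}, names
)
```

Here groups["air_temperature_"] holds both levels under one rate, which is what lets them share a single random draw per sample.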

Example: mask global_mean_co2 80% of the time, mask all air temperature levels together 20% of the time, and additionally mask between 1 and 3 variables per sample from the remaining uniform pool while never selecting SST in that uniform pool:

input_dropout:
  uniform:
    min_vars: 1
    max_vars: 3
    ignore_vars: [SST]
  per_variable:
    global_mean_co2: 0.8
    air_temperature_: 0.2
    default_rate: 0.0

UniformMaskingConfig

Nested under VariableMaskingConfig.uniform. It masks a uniformly sampled count of variables per sample:

  • min_vars (int or "min", default "min"): Minimum number of variables to mask per sample. "min" resolves to 0.
  • max_vars (int or "max", default "max"): Maximum number of variables to mask per sample. "max" resolves to the number of eligible variables.
  • ignore_vars (list[str], default []): Variables excluded from uniform masking only. Explicit per_variable rates still apply to these variables.

Each sample independently draws a random integer n from [min_vars, max_vars] and then selects n variables uniformly at random from the eligible uniform pool.
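A minimal sketch of the count-based draw for one sample, using only the standard library. It assumes the range [min_vars, max_vars] is inclusive on both ends and that max_vars is clamped to the pool size; neither detail is confirmed by this description. The function name and the variable names other than SST and global_mean_co2 are hypothetical.

```python
import random


def sample_uniform_mask(pool, min_vars="min", max_vars="max", rng=None):
    """Return {name: present} for one sample (True = present).

    Hypothetical sketch; not the PR's implementation.
    """
    rng = rng or random.Random()
    lo = 0 if min_vars == "min" else min_vars
    hi = len(pool) if max_vars == "max" else max_vars
    hi = min(hi, len(pool))  # assumption: never request more than the pool holds
    lo = min(lo, hi)
    n = rng.randint(lo, hi)  # inclusive on both ends (assumption)
    masked = set(rng.sample(pool, n))  # n distinct variables, uniformly at random
    return {name: name not in masked for name in pool}


variables = ["SST", "PRESsfc", "u10", "v10", "global_mean_co2"]
ignore_vars = ["SST"]
# ignore_vars only removes variables from the uniform pool.
pool = [v for v in variables if v not in ignore_vars]
mask = sample_uniform_mask(pool, min_vars=1, max_vars=3, rng=random.Random(0))
```

Because SST never enters the pool, it never appears in the returned mask; any per_variable rate for SST would be handled separately.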

Combined behavior

When both sections are configured:

  1. Exact and prefix per_variable groups are sampled first.
  2. Variables handled by exact or prefix entries are excluded from the uniform pool, even when their configured rate is 0.0.
  3. default_rate applies independently to unmatched variables.
  4. Variables handled only by default_rate are excluded from the uniform pool only when default_rate > 0.
  5. Any overlapping masks are combined with logical AND on presence (True = present), so a variable is masked if any source masks it.
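Rules 2 and 4 determine which variables remain eligible for uniform masking. The sketch below computes that pool under the rules as listed; uniform_pool is a hypothetical helper, and PRESsfc and u10 are invented variable names.

```python
def uniform_pool(variables, per_variable, ignore_vars=()):
    """Return the variables still eligible for uniform masking.

    Hypothetical helper for illustration; not the PR's implementation.
    """
    cfg = dict(per_variable)
    default_rate = cfg.pop("default_rate", 0.0)
    handled = set()
    for key in cfg:
        for v in variables:
            is_prefix_match = (
                key.endswith("_") and v.startswith(key) and v[len(key):].isdigit()
            )
            # Rule 2: exact or prefix matches are handled, even at rate 0.0.
            if v == key or is_prefix_match:
                handled.add(v)
    if default_rate > 0:
        # Rule 4: a positive default_rate handles every remaining variable.
        handled.update(variables)
    return [v for v in variables if v not in handled and v not in ignore_vars]


variables = [
    "air_temperature_0",
    "air_temperature_1",
    "global_mean_co2",
    "SST",
    "PRESsfc",
    "u10",
]
cfg = {"global_mean_co2": 0.8, "air_temperature_": 0.0, "default_rate": 0.0}
pool = uniform_pool(variables, cfg, ignore_vars=["SST"])
```

With this config the pool is PRESsfc and u10. Setting default_rate above zero empties the pool, since every unmatched variable is then handled by default_rate.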

SingleModuleStepConfig.input_dropout

Set to a VariableMaskingConfig to enable input dropout, or to null to disable it. Dropout is only active when the underlying PyTorch module is in training mode; it is automatically disabled during inference.

Effect on training

At each training step:

  1. The dropout config samples a per-sample boolean mask (True = present, False = masked).
  2. This mask is ANDed with any data_mask already present in the batch, so training-time dropout and dataset-level absence are handled uniformly.
  3. Masked variables are zeroed in normalized space before the network forward pass.
  4. If include_channel_mask_inputs=True, the indicator channels (0.0 = absent, 1.0 = present) also reflect the combined mask, giving the model explicit information about which inputs are unavailable.
  5. If global_mean_removal is configured with append_as_input=True, the extra global-mean channels for masked variables are also zeroed, preventing the network from inferring a dropped variable's value through those channels.
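Steps 2 to 4 can be sketched with plain Python lists standing in for tensors, one scalar per sample. The helper names and data layout are assumptions made for illustration, and the global-mean extra channels from step 5 are left out.

```python
def combine_masks(dropout_mask, data_mask=None):
    """AND two presence masks (True = present).

    A variable missing from one dict is treated as fully present on that side.
    Hypothetical helper; not the PR's implementation.
    """
    data_mask = data_mask or {}
    combined = {}
    for name in {**dropout_mask, **data_mask}:
        a = dropout_mask.get(name)
        b = data_mask.get(name)
        if a is None:
            combined[name] = list(b)
        elif b is None:
            combined[name] = list(a)
        else:
            combined[name] = [x and y for x, y in zip(a, b)]
    return combined


def apply_mask(normalized, mask):
    """Zero masked samples and build 0.0/1.0 indicator values."""
    zeroed, indicators = {}, {}
    for name, values in normalized.items():
        present = mask.get(name, [True] * len(values))
        zeroed[name] = [v if p else 0.0 for v, p in zip(values, present)]
        indicators[name] = [1.0 if p else 0.0 for p in present]
    return zeroed, indicators


# Two samples per variable, already in normalized space.
normalized = {"SST": [0.5, -1.2], "u10": [0.3, 0.7]}
dropout = {"SST": [True, False], "u10": [True, True]}
data_mask = {"u10": [False, True]}  # dataset-level absence for sample 0
zeroed, indicators = apply_mask(normalized, combine_masks(dropout, data_mask))
```

In this example SST is zeroed in the second sample by dropout and u10 is zeroed in the first sample by data_mask, and the indicator values report both cases the same way.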

Changes:

  • fme.core.var_masking.VariableMaskingConfig: merged per-variable and uniform count-based masking into one config with per_variable and optional nested uniform.

  • fme.core.var_masking.UniformMaskingConfig: replaces the previous separate uniform masking config as the nested config object for VariableMaskingConfig.uniform.

  • fme.core.step.single_module.SingleModuleStepConfig.input_dropout: new optional VariableMaskingConfig | None field; during training, samples masks and ANDs them with args.data_mask before the network forward pass.

  • fme.core.step.single_module._apply_extra_channel_mask: zeros global-mean-removal extra channels for masked variables so the network cannot infer a dropped field's mean.

  • fme.core.step.global_mean_removal.GlobalMeanRemoval.extra_channel_names: abstract property implemented on NoGlobalMeanRemoval, SharedGlobalMeanRemoval, and PerChannelGlobalMeanRemoval; used by _apply_extra_channel_mask to identify which extra channels to zero.

  • fme.ace.__init__: exports VariableMaskingConfig and UniformMaskingConfig.

  • Tests added

  • No dependencies changed; deps-only image rebuild and latest_deps_only_image.txt update not needed.

Comment thread fme/core/step/single_module.py Outdated
@yyexela yyexela force-pushed the feature/var_masking_dropout_train_stepper branch 2 times, most recently from fbf5b17 to 9e85db4 Compare May 28, 2026 15:39
@yyexela yyexela force-pushed the feature/var_masking_dropout_train_stepper branch from 10dc4db to c01ea7b Compare June 4, 2026 23:40
yyexela commented Jun 4, 2026

Regarding masking during inference:

In _build_effective_input_mask (fme/core/step/single_module.py:496):

result = dict(data_mask) if data_mask is not None else {}
if input_dropout is None or not training:
    return result

The training argument is passed directly from self.module.torch_module.training (line 404), which PyTorch sets to False when the module is in eval mode. When training is False, the function returns before calling input_dropout.sample_masks, so no dropout masks are ever generated. Only the dataset-level data_mask (if any) is applied.

This is also verified by test_input_dropout_applied_in_train_mode_not_eval (fme/core/step/test_step.py:1316), which runs the same step in both modes with rate=1.0 and asserts the indicator channel is 0.0 (masked) in train mode and 1.0 (present) in eval mode.
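The early-return behavior can be reproduced in isolation. In the sketch below only the first three lines of the function body come from the excerpt above; the merge loop, the names argument, the sample_masks signature, and the AlwaysMask stub are assumptions, and each mask is a single boolean per variable for brevity.

```python
class AlwaysMask:
    """Stub sampler that masks every variable, like rate=1.0 (hypothetical)."""

    def sample_masks(self, names):
        # Assumed signature; the real sample_masks may take other arguments.
        return {name: False for name in names}


def build_effective_input_mask(data_mask, input_dropout, training, names):
    result = dict(data_mask) if data_mask is not None else {}
    if input_dropout is None or not training:
        return result  # eval mode: only dataset-level masks survive
    # Assumed merge step: AND sampled masks with any existing entry.
    for name, sampled in input_dropout.sample_masks(names).items():
        result[name] = result.get(name, True) and sampled
    return result


train_mask = build_effective_input_mask(None, AlwaysMask(), True, ["SST"])
eval_mask = build_effective_input_mask(None, AlwaysMask(), False, ["SST"])
```

In train mode the stub's mask reaches the result, while in eval mode the function returns before the sampler is ever called.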
