[WIP] Feature/var masking dropout train stepper#1198
Open
yyexela wants to merge 6 commits into
Open
Conversation
yyexela
commented
May 27, 2026
fbf5b17 to
9e85db4
Compare
10dc4db to
c01ea7b
Compare
Contributor
Author
|
Regarding masking during inference: In The training argument is passed directly from This is also verified by |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds training-time input variable dropout to
SingleModuleStepConfigso models can be trained to be robust to missing input variables.VariableMaskingConfignow supports both per-variable rates with optional prefix-group support for leveled fields and uniform random selection of a variable count per sample. The effective mask merges with any existingdata_maskso both absence sources consistently zero the network input, channel mask indicators, and global-mean-removal extra channels.Configuration options
VariableMaskingConfigControls input dropout with two optional sections:
per_variable(dict[str, float], default{}): Per-variable mask probabilities in[0, 1]. A rate of1.0always masks;0.0explicitly handles the variable without returning a mask for it. Keys ending with_are prefix keys: they match all variables of the form{prefix}{digits}(for example,"air_temperature_"matchesair_temperature_0,air_temperature_1, etc.) and all matched levels share a single random draw per sample.per_variable.default_rate(float, default0.0): Reserved key applied independently to variables not listed by exact or prefix keys.uniform(UniformMaskingConfig | None, defaultNone): Optional uniform count-based masking config withmin_vars,max_vars, andignore_vars.Example: mask
global_mean_co280% of the time, mask all air temperature levels together 20% of the time, and additionally mask between 1 and 3 variables per sample from the remaining uniform pool while never selecting SST in that uniform pool:UniformMaskingConfigNested under
VariableMaskingConfig.uniform. It masks a uniformly sampled count of variables per sample:min_vars(int or"min", default"min"): Minimum number of variables to mask per sample."min"resolves to0.max_vars(int or"max", default"max"): Maximum number of variables to mask per sample."max"resolves to the number of eligible variables.ignore_vars(list[str], default[]): Variables excluded from uniform masking only. Explicitper_variablerates still apply to these variables.Each sample independently draws a random integer
nfrom[min_vars, max_vars]and then selectsnvariables uniformly at random from the eligible uniform pool.Combined behavior
When both sections are configured:
per_variablegroups are sampled first.0.0.default_rateapplies independently to unmatched variables.default_rateare excluded from the uniform pool only whendefault_rate > 0.SingleModuleStepConfig.input_dropoutSet to
VariableMaskingConfigornullto disable. Dropout is only active when the underlying PyTorch module is in training mode; it is automatically disabled during inference.Effect on training
At each training step:
True= present,False= masked).data_maskalready present in the batch, so training-time dropout and dataset-level absence are handled uniformly.include_channel_mask_inputs=True, the indicator channels (0.0= absent,1.0= present) also reflect the combined mask, giving the model explicit information about which inputs are unavailable.global_mean_removalis configured withappend_as_input=True, the extra global-mean channels for masked variables are also zeroed, preventing the network from inferring a dropped variable's value through those channels.Changes:
fme.core.var_masking.VariableMaskingConfig: merged per-variable and uniform count-based masking into one config withper_variableand optional nesteduniform.fme.core.var_masking.UniformMaskingConfig: replaces the previous separate uniform masking config as the nested config object forVariableMaskingConfig.uniform.fme.core.step.single_module.SingleModuleStepConfig.input_dropout: new optionalVariableMaskingConfig | Nonefield; during training, samples masks and ANDs them withargs.data_maskbefore the network forward pass.fme.core.step.single_module._apply_extra_channel_mask: zeros global-mean-removal extra channels for masked variables so the network cannot infer a dropped field's mean.fme.core.step.global_mean_removal.GlobalMeanRemoval.extra_channel_names: abstract property implemented onNoGlobalMeanRemoval,SharedGlobalMeanRemoval, andPerChannelGlobalMeanRemoval; used by_apply_extra_channel_maskto identify which extra channels to zero.fme.ace.__init__: exportsVariableMaskingConfigandUniformMaskingConfig.Tests added
No dependencies changed; deps-only image rebuild and
latest_deps_only_image.txtupdate not needed.