Fix masked variable handling #1174
Open
yyexela wants to merge 19 commits into
Outlines the plan for per-sample variable masking in ACE training, inference, and data loading, enabling training on heterogeneous data sources where some datasets are missing variables. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add a `data_mask: TensorMapping | None` field to BatchData for per-sample, per-variable masking. Propagate it through all methods that create new BatchData instances (to_device, to_cpu, broadcast_ensemble, subset_names, etc). Add validation and tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Wire the flag through SingleModuleStepConfig -> StepConfigABC -> StepSelector -> StepperConfig -> DataRequirements. The flag is a safety gate: when False (default), missing required variables cause an error. When True, the data pipeline may omit variables and provide masks instead. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Thread data_mask from BatchData through predict_generator and _accumulate_loss into StepArgs, so it is available to step implementations for input and output masking. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Zero out masked input variables in normalized space before the network call (_apply_input_mask) and zero out both predictions and targets for masked output variables before loss computation (_apply_output_mask). This ensures masked variables don't influence training gradients. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When allow_variable_masking is True, XarrayDataset skips variables not found in the data files instead of raising an error. CollateFn handles heterogeneous variable sets across samples by taking the union of names, filling missing variables with zeros, and producing a data_mask. The allow_variable_masking flag is threaded from DataRequirements through get_gridded_data -> DataLoaderConfig -> dataset build chain -> XarrayDataset, and also to CollateFn -> BatchData.from_sample_tuples. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
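The collation step described in this commit can be sketched in plain Python. This is a minimal illustration, not the code in `CollateFn` or `BatchData.from_sample_tuples`: the `collate_with_masking` function, the `Sample` alias, and the list-based "tensors" are all hypothetical stand-ins, and the mask convention (True = present) is assumed from the mask-channel commit later in this PR.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a sample: variable name -> flat list of values.
Sample = Dict[str, List[float]]


def collate_with_masking(
    samples: List[Sample], fill_value: float = 0.0
) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[bool]]]:
    """Stack samples over the union of variable names and build a mask."""
    # Union of names across samples, in first-seen order.
    names = list(dict.fromkeys(name for sample in samples for name in sample))
    data: Dict[str, List[List[float]]] = {}
    mask: Dict[str, List[bool]] = {}
    for name in names:
        # Infer the variable's length from the first sample that has it.
        length = len(next(s[name] for s in samples if name in s))
        data[name] = [
            list(s[name]) if name in s else [fill_value] * length
            for s in samples
        ]
        # True where the sample actually provided the variable.
        mask[name] = [name in s for s in samples]
    return data, mask


batch, data_mask = collate_with_masking(
    [{"T": [1.0, 2.0], "q": [0.1, 0.2]}, {"T": [3.0, 4.0]}]
)
```

Here the second sample lacks `q`, so its row is filled with zeros and `data_mask["q"]` marks it as absent.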
Pass allow_variable_masking from DataRequirements through to InferenceDataset, which forwards it to XarrayDataset construction and BatchData.from_sample_tuples collation. Also update get_forcing_data to pass the flag when constructing XarrayDataset for time index discovery. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…s chain Fill missing variables with NaN instead of zeros during collation so that the existing NaN masking in WeightedMappingLoss serves as a safety net. Propagate data_mask through StepLoss -> WeightedMappingLoss -> LossOutput so that total() uses a masked mean, preventing absent variables from diluting the optimization target. Guard _get_variable_metadata with allow_variable_masking to raise ValueError for missing variables when masking is disabled. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
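To illustrate why a masked mean matters here, the following sketch averages per-sample, per-variable losses over present entries only. It assumes a simplified data layout (dicts of Python floats); `masked_total` is a hypothetical helper and does not reproduce the `LossOutput.total()` implementation.

```python
import math
from typing import Dict, List


def masked_total(
    per_variable_loss: Dict[str, List[float]],
    data_mask: Dict[str, List[bool]],
) -> float:
    """Mean over (sample, variable) entries that are present and finite."""
    total = 0.0
    count = 0
    for name, losses in per_variable_loss.items():
        for loss, present in zip(losses, data_mask[name]):
            # Skip masked entries; the NaN check is the safety net for
            # NaN-filled variables that slipped past the mask.
            if present and not math.isnan(loss):
                total += loss
                count += 1
    return total / count if count else 0.0


losses = {"T": [2.0, 4.0], "q": [6.0, math.nan]}
mask = {"T": [True, True], "q": [True, False]}
total = masked_total(losses, mask)
```

With three present entries the result is 4.0, whereas a plain mean that counted the absent entry as zero would give 3.0 and dilute the optimization target.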
Use per-sample data_mask (pre-broadcast) instead of the ensemble-broadcast mask for _apply_output_mask and StepLoss, since gen_step/target_step have been unfolded back to [batch, ensemble, ...]. The ensemble-broadcast mask had shape [batch*n_ensemble] which doesn't broadcast against the unfolded tensors when n_ensemble > 1. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ve tests
- In _collate_with_masking, cast non-floating-point dtypes to float32 before filling with NaN to avoid RuntimeError on integer tensors.
- Delete masked-variables-design.md from repo root (content belongs in the PR description).
- Add test for integer dtype handling in collation.
- Strengthen concat masking test to verify data_mask contents.
- Clarify inference masking test expectations.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When enabled, appends per-variable mask indicator channels (1.0 = present, 0.0 = masked) to the packed network input, doubling the input channel count. This lets the model distinguish between physically zero inputs and masked (zeroed) inputs. The mask channels are built separately from the data packer and concatenated along the channel dim after packing, keeping the packer clean and avoiding normalization of the mask values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
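A rough sketch of the channel layout this commit describes, for a single sample and using nested lists in place of tensors. `pack_with_mask_channels` is a hypothetical name, and the real packer and concatenation operate on tensors along the channel dimension.

```python
from typing import Dict, List


def pack_with_mask_channels(
    normalized: Dict[str, List[float]],
    present: Dict[str, bool],
    names: List[str],
) -> List[List[float]]:
    """Return data channels followed by one mask channel per variable."""
    data_channels: List[List[float]] = []
    mask_channels: List[List[float]] = []
    for name in names:
        values = normalized[name]
        if present[name]:
            data_channels.append(list(values))
            mask_channels.append([1.0] * len(values))  # 1.0 = present
        else:
            # Masked inputs are zeroed in normalized space.
            data_channels.append([0.0] * len(values))
            mask_channels.append([0.0] * len(values))  # 0.0 = masked
    # Mask channels are appended after the packed data, never normalized.
    return data_channels + mask_channels


packed = pack_with_mask_channels(
    {"T": [0.5, -0.5], "q": [float("nan"), float("nan")]},
    {"T": True, "q": False},
    ["T", "q"],
)
```

Two input variables yield four channels: the network can tell the zeroed `q` channel apart from a physically zero input because its mask channel is 0.0.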
Added the expected loss_value and mask shapes to the _reduce_to_per_channel docstring to make the expected input sizes clearer. Also made the asserts in test_step.py more readable.
…asking=True Previously, self._names was initialized from the caller-supplied names list and then manually reconciled after _group_variable_names_by_time_type() to drop variables missing from the dataset. This left a window where self._names and the three partitioned lists were out of sync, and required defensive guards in _get_variable_metadata to handle names that had already been silently skipped. Replace self._names with a property derived from the three discovered lists (_time_dependent_names, _time_invariant_names, _static_derived_names), which are the single source of truth after discovery. The original requested names are stored in self._requested_names for use only inside _group_variable_names_by_time_type. Move the _get_variable_metadata call to after discovery so it iterates only over found variables, allowing removal of the missing-variable guards there. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
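The bookkeeping described above can be sketched as follows. `NameBook` and its `available_kinds` argument are hypothetical stand-ins for illustration only; the attribute names mirror the ones in the commit message, but the discovery logic in `XarrayDataset` reads the data files rather than a dict.

```python
from typing import Dict, List


class NameBook:
    """Hypothetical stand-in for the name handling in XarrayDataset."""

    def __init__(
        self,
        requested_names: List[str],
        available_kinds: Dict[str, str],  # assumed: name -> group kind
        allow_variable_masking: bool = False,
    ):
        self._requested_names = list(requested_names)
        self._time_dependent_names: List[str] = []
        self._time_invariant_names: List[str] = []
        self._static_derived_names: List[str] = []
        groups = {
            "time_dependent": self._time_dependent_names,
            "time_invariant": self._time_invariant_names,
            "static_derived": self._static_derived_names,
        }
        for name in self._requested_names:
            kind = available_kinds.get(name)
            if kind is None:
                if allow_variable_masking:
                    continue  # skipped; masking is expected to cover it
                raise ValueError(f"missing required variable: {name}")
            groups[kind].append(name)

    @property
    def _names(self) -> List[str]:
        # Derived from the discovered lists, so it cannot drift from them.
        return (
            self._time_dependent_names
            + self._time_invariant_names
            + self._static_derived_names
        )


kinds = {"T": "time_dependent", "HGTsfc": "time_invariant"}
book = NameBook(
    ["T", "HGTsfc", "nonexistent_var"], kinds, allow_variable_masking=True
)
```

Because `_names` is computed on access, there is no window in which it and the three partitioned lists disagree.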
…ames The added unit test demonstrates that a requested variable name that is absent from the dataset is not added to the dataset's _names.
Docstring update to be more descriptive, cleaner logic, removed double-space typo.
2 tasks
yyexela commented May 15, 2026
Comment on lines +1252 to +1266

    def test_requested_names_and_names_differ_when_variable_masked(mock_monthly_netcdfs):
        mock_data: MockData = mock_monthly_netcdfs
        config = XarrayDataConfig(data_path=mock_data.tmpdir)
        existing_names = list(mock_data.var_names.time_dependent_names)
        names_with_missing = existing_names + ["nonexistent_var"]
        dataset = XarrayDataset(
            config,
            names_with_missing,
            IntSchedule.from_constant(2),
            allow_variable_masking=True,
        )
        assert "nonexistent_var" in dataset._requested_names
        assert "nonexistent_var" not in dataset._names

Author
Note: This test shows that `_requested_names` can contain more names than are present in the dataset; only the names actually found are populated in `_names`.
yyexela commented May 15, 2026
Comment on lines -753 to -755

    ds = self._open_file(0)
    self._get_variable_metadata(ds)
    ds.close()

Author
This was moved to the bottom of `__init__` so that `_names` is updated before it is called. I think it looks cleaner this way anyway.
Author
Added detailed explanations of all changes to #1160
This PR fixes two related bugs in the masked-variable pipeline and cleans up the loss masking implementation.

Bug fix: `XarrayDataset._names` inconsistency (`fme/core/dataset/xarray.py`)

When `allow_variable_masking=True`, `_names` was being silently overwritten in `__init__` to only contain the variables actually found on disk, diverging from the originally requested names. This caused downstream logic (e.g. `_group_variable_names_by_time_type`) to skip missing variables even when masking was expected to handle them. The fix renames the constructor-stored attribute to `_requested_names` (used only for grouping) and makes `_names` a computed property that always returns the union of time-dependent, time-invariant, and static-derived variables discovered in the dataset.

Cleanup: remove `_apply_output_mask` (`fme/ace/stepper/single_module.py`)

The `_apply_output_mask` helper, which zeroed out masked tensors before the loss call, is removed. `WeightedMappingLoss` already handles per-sample masking via `data_mask` directly, so the pre-zeroing step was redundant and added surface area for subtle mismatches.

Minor improvements

- Unique-name collection in `_collate_with_masking` in `batch_data.py` simplified to a single `dict.fromkeys` expression.
- Docstrings for `_reduce_to_per_channel` and `BatchData` expanded for clarity.
- Tests in `test_xarray.py`, `test_step.py`, `test_loss.py`, `test_batch_data.py`, and `test_data_loader.py` made more robust, descriptive, and comprehensive for masked-variable scenarios.

Changes:

- `fme.core.dataset.xarray.XarrayDataset`: fixed `_names` inconsistency when `allow_variable_masking=True`; `_names` is now a computed property always reflecting discovered variables, with `_requested_names` used internally for grouping
- `fme.ace.stepper.single_module._apply_output_mask`, `fme.ace.stepper.single_module.TrainStepper`: removed `_apply_output_mask` and its call site; masking is handled by `WeightedMappingLoss` directly
- `fme.ace.data_loading.batch_data._collate_with_masking`, `fme.ace.data_loading.batch_data.BatchData`: simplified unique-name collection; improved docstring
- `fme.core.loss._reduce_to_per_channel`: expanded docstring

- Tests added
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated