Fix masked variable handling #1174

Open
yyexela wants to merge 19 commits into main from feature/masked_variables_changes

@yyexela yyexela commented May 14, 2026

This PR fixes two related bugs in the masked-variable pipeline and cleans up the loss masking implementation.

Bug fix — XarrayDataset._names inconsistency (fme/core/dataset/xarray.py)

When allow_variable_masking=True, _names was being silently overwritten in __init__ to only contain the
variables actually found on disk, diverging from the originally requested names. This caused downstream logic
(e.g. _group_variable_names_by_time_type) to skip missing variables even when masking was expected to handle
them. The fix renames the constructor-stored attribute to _requested_names (used only for grouping) and makes
_names a computed property that always returns the union of time-dependent, time-invariant, and static-derived
variables discovered in the dataset.
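A minimal sketch of the pattern described above, assuming masking is enabled. `DatasetSketch` and its `available` argument are hypothetical stand-ins for illustration and are not the actual XarrayDataset code: the requested names are stored separately, and `_names` is derived from the three discovered lists rather than assigned.

```python
class DatasetSketch:
    """Hypothetical stand-in for XarrayDataset; not the real implementation."""

    def __init__(self, requested_names, available):
        # Names the caller asked for; consulted only while grouping.
        self._requested_names = list(requested_names)
        # Discovered lists, keyed by kind; the single source of truth afterwards.
        self._groups = {
            "time_dependent": [],
            "time_invariant": [],
            "static_derived": [],
        }
        for name in self._requested_names:
            kind = available.get(name)
            if kind is None:
                continue  # missing variable: left to masking downstream
            self._groups[kind].append(name)

    @property
    def _names(self):
        # Computed on access, so it cannot drift from the discovered lists.
        return (
            self._groups["time_dependent"]
            + self._groups["time_invariant"]
            + self._groups["static_derived"]
        )


sketch = DatasetSketch(
    ["air_temperature", "land_fraction", "nonexistent_var"],
    {"air_temperature": "time_dependent", "land_fraction": "time_invariant"},
)
```

Here `sketch._requested_names` still contains `nonexistent_var`, while `sketch._names` lists only the two variables that were found.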

Cleanup — remove _apply_output_mask (fme/ace/stepper/single_module.py)

The _apply_output_mask helper, which zeroed out masked tensors before the loss call, is removed.
WeightedMappingLoss already handles per-sample masking via data_mask directly, so the pre-zeroing step was
redundant and added surface area for subtle mismatches.
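To illustrate why pre-zeroing is redundant once the loss itself is mask-aware, here is a small pure-Python sketch. `masked_mean_abs_error` is a hypothetical generic reduction written for this example; it is not how WeightedMappingLoss is implemented.

```python
def masked_mean_abs_error(pred, target, mask):
    """Mean absolute error over entries whose mask is nonzero."""
    total = 0.0
    count = 0
    for p, t, m in zip(pred, target, mask):
        if m:
            total += abs(p - t)
            count += 1
    return total / count if count else 0.0


pred = [1.0, 5.0, 2.0]
target = [0.5, 9.0, 2.5]
mask = [1, 0, 1]  # the middle sample is masked

# Loss computed directly on the raw values.
loss_plain = masked_mean_abs_error(pred, target, mask)

# Loss computed after zeroing masked predictions and targets first.
pred_zeroed = [p if m else 0.0 for p, m in zip(pred, mask)]
target_zeroed = [t if m else 0.0 for t, m in zip(target, mask)]
loss_prezeroed = masked_mean_abs_error(pred_zeroed, target_zeroed, mask)
```

Both calls skip the masked sample entirely, so `loss_plain` and `loss_prezeroed` are identical and the extra zeroing pass has no effect on the result.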

Minor improvements

  • _collate_with_masking in batch_data.py simplified to a single dict.fromkeys expression.
  • Docstrings for _reduce_to_per_channel and BatchData expanded for clarity.
  • Tests across test_xarray.py, test_step.py, test_loss.py, test_batch_data.py, and test_data_loader.py made
    more robust, descriptive, and comprehensive for masked-variable scenarios.
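For readers unfamiliar with the `dict.fromkeys` idiom mentioned in the first bullet, a short sketch follows. The sample dictionaries are invented for illustration, and the exact expression used in `_collate_with_masking` may differ.

```python
# Two samples with different variable sets (hypothetical data).
samples = [
    {"temperature": 1.0, "pressure": 2.0},
    {"pressure": 3.0, "humidity": 4.0},
]

# dict.fromkeys drops duplicates while keeping first-seen order,
# unlike set(), whose iteration order is not tied to insertion order.
unique_names = list(dict.fromkeys(name for sample in samples for name in sample))
```

The result keeps `temperature`, `pressure`, `humidity` in the order they were first encountered, which keeps variable ordering deterministic across runs.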

Changes:

  • fme.core.dataset.xarray.XarrayDataset — fixed _names inconsistency when allow_variable_masking=True; _names
    is now a computed property always reflecting discovered variables, with _requested_names used internally for
    grouping

  • fme.ace.stepper.single_module._apply_output_mask, fme.ace.stepper.single_module.TrainStepper — removed
    _apply_output_mask and its call site; masking is handled by WeightedMappingLoss directly

  • fme.ace.data_loading.batch_data._collate_with_masking, fme.ace.data_loading.batch_data.BatchData
    simplified unique-name collection; improved docstring

  • fme.core.loss._reduce_to_per_channel — expanded docstring

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

mcgibbon and others added 19 commits May 12, 2026 19:44
Outlines the plan for per-sample variable masking in ACE training,
inference, and data loading, enabling training on heterogeneous data
sources where some datasets are missing variables.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add a `data_mask: TensorMapping | None` field to BatchData for
per-sample, per-variable masking. Propagate it through all methods
that create new BatchData instances (to_device, to_cpu, broadcast_ensemble,
subset_names, etc). Add validation and tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Wire the flag through SingleModuleStepConfig -> StepConfigABC ->
StepSelector -> StepperConfig -> DataRequirements. The flag is a
safety gate: when False (default), missing required variables cause
an error. When True, the data pipeline may omit variables and provide
masks instead.
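The safety-gate behaviour can be sketched as follows. `select_variables` is a hypothetical helper written for this example, not a function in fme; only the flag name `allow_variable_masking` comes from the commit.

```python
def select_variables(required, available, allow_variable_masking=False):
    """Split required names into (found, missing).

    Raises ValueError when something is missing and masking is not allowed.
    """
    found = [name for name in required if name in available]
    missing = [name for name in required if name not in available]
    if missing and not allow_variable_masking:
        raise ValueError(f"missing required variables: {missing}")
    return found, missing


# With the gate open, missing variables are reported instead of raising.
found, missing = select_variables(
    ["a", "b"], available={"a"}, allow_variable_masking=True
)
```

With the default `allow_variable_masking=False`, the same call raises `ValueError`, which matches the described default of failing loudly on missing data.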

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Thread data_mask from BatchData through predict_generator and
_accumulate_loss into StepArgs, so it is available to step
implementations for input and output masking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Zero out masked input variables in normalized space before the network
call (_apply_input_mask) and zero out both predictions and targets for
masked output variables before loss computation (_apply_output_mask).
This ensures masked variables don't influence training gradients.
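A rough sketch of the zeroing step, using plain lists in place of tensors. `zero_masked` is a hypothetical helper and the data layout (variable name mapped to one value per sample) is an assumption made to keep the example small.

```python
def zero_masked(data, data_mask):
    """Return a copy of data with masked samples set to 0.0.

    data:      name -> list of per-sample values
    data_mask: name -> list of per-sample flags (1 = present, 0 = masked)
    Variables without a mask entry are passed through unchanged.
    """
    result = {}
    for name, values in data.items():
        mask = data_mask.get(name)
        if mask is None:
            result[name] = list(values)
        else:
            # Select rather than multiply: NaN * 0 is still NaN.
            result[name] = [v if m else 0.0 for v, m in zip(values, mask)]
    return result


normalized = {"t": [0.3, float("nan")], "q": [1.2, -0.7]}
mask = {"t": [1, 0]}
masked_inputs = zero_masked(normalized, mask)
```

Selecting with a conditional instead of multiplying by the mask matters if masked entries hold NaN, since multiplication would propagate the NaN into the network input.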

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When allow_variable_masking is True, XarrayDataset skips variables not
found in the data files instead of raising an error. CollateFn handles
heterogeneous variable sets across samples by taking the union of names,
filling missing variables with zeros, and producing a data_mask.

The allow_variable_masking flag is threaded from DataRequirements through
get_gridded_data -> DataLoaderConfig -> dataset build chain -> XarrayDataset,
and also to CollateFn -> BatchData.from_sample_tuples.
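The collation behaviour described in this commit (union of names, zero fill, and a `data_mask`) might look roughly like the sketch below. It uses scalars in place of tensors, and `collate_with_masking_sketch` is a hypothetical simplification rather than the real CollateFn.

```python
def collate_with_masking_sketch(samples, fill_value=0.0):
    """Collate samples with differing variable sets.

    samples: list of dicts mapping variable name -> value
    Returns (batch, data_mask), both mapping name -> per-sample list.
    """
    # Union of names across samples, in first-seen order.
    names = list(dict.fromkeys(name for sample in samples for name in sample))
    # Missing variables are filled with fill_value.
    batch = {n: [s.get(n, fill_value) for s in samples] for n in names}
    # 1.0 where the sample had the variable, 0.0 where it was filled.
    data_mask = {n: [1.0 if n in s else 0.0 for s in samples] for n in names}
    return batch, data_mask


batch, data_mask = collate_with_masking_sketch([{"a": 1.0, "b": 2.0}, {"a": 3.0}])
```

The second sample lacks `b`, so `batch["b"]` receives the fill value in that position and `data_mask["b"]` records it as masked.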

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pass allow_variable_masking from DataRequirements through to
InferenceDataset, which forwards it to XarrayDataset construction
and BatchData.from_sample_tuples collation. Also update
get_forcing_data to pass the flag when constructing XarrayDataset
for time index discovery.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…s chain

Fill missing variables with NaN instead of zeros during collation so that
the existing NaN masking in WeightedMappingLoss serves as a safety net.
Propagate data_mask through StepLoss -> WeightedMappingLoss -> LossOutput
so that total() uses a masked mean, preventing absent variables from
diluting the optimization target. Guard _get_variable_metadata with
allow_variable_masking to raise ValueError for missing variables when
masking is disabled.
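The dilution that a masked mean avoids can be shown with a small numeric sketch. It is simplified to one loss value per variable, whereas the commit works with per-sample masks; `masked_total` is a hypothetical helper, not the `total()` method itself.

```python
def masked_total(per_variable_loss, present):
    """Mean of the losses for variables that are actually present."""
    kept = [loss for name, loss in per_variable_loss.items() if present[name]]
    return sum(kept) / len(kept) if kept else 0.0


# "c" is absent from the data; its NaN loss must not reach the total.
per_variable_loss = {"a": 0.5, "b": 0.25, "c": float("nan")}
present = {"a": True, "b": True, "c": False}

total = masked_total(per_variable_loss, present)

# For comparison: treating the absent variable as a zero loss and taking
# a plain mean over all three variables shrinks the optimization target.
diluted = (0.5 + 0.25 + 0.0) / 3
```

The masked mean averages over the two present variables only, while the plain mean divides by three and so understates the loss on the variables that are actually being trained.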

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Use per-sample data_mask (pre-broadcast) instead of the
ensemble-broadcast mask for _apply_output_mask and StepLoss, since
gen_step/target_step have been unfolded back to [batch, ensemble, ...].
The ensemble-broadcast mask had shape [batch*n_ensemble] which doesn't
broadcast against the unfolded tensors when n_ensemble > 1.
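The shape mismatch can be checked without any tensor library by applying the standard NumPy/PyTorch broadcasting rule by hand. The concrete sizes and the trailing singleton dimensions on the mask are assumptions chosen for illustration.

```python
from itertools import zip_longest


def broadcastable(shape_a, shape_b):
    """Apply the broadcasting rule: align shapes from the right;
    each pair of dims must be equal or one of them must be 1."""
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    return all(a == b or a == 1 or b == 1 for a, b in pairs)


batch, n_ensemble, height, width = 2, 3, 4, 8

# gen_step / target_step after unfolding back to [batch, ensemble, ...].
unfolded = (batch, n_ensemble, height, width)

# Mask built after ensemble broadcast: leading dim is batch * n_ensemble.
ensemble_mask = (batch * n_ensemble, 1, 1, 1)

# Per-sample (pre-broadcast) mask: leading dim is batch.
per_sample_mask = (batch, 1, 1, 1)

ensemble_ok = broadcastable(ensemble_mask, unfolded)
per_sample_ok = broadcastable(per_sample_mask, unfolded)
```

With `n_ensemble > 1` the leading dimensions are `batch * n_ensemble` versus `batch`, neither of which is 1, so only the per-sample mask is compatible with the unfolded tensors.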

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ve tests

- In _collate_with_masking, cast non-floating-point dtypes to float32
  before filling with NaN to avoid RuntimeError on integer tensors.
- Delete masked-variables-design.md from repo root (content belongs in
  the PR description).
- Add test for integer dtype handling in collation.
- Strengthen concat masking test to verify data_mask contents.
- Clarify inference masking test expectations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When enabled, appends per-variable mask indicator channels (1.0 =
present, 0.0 = masked) to the packed network input, doubling the
input channel count. This lets the model distinguish between physically
zero inputs and masked (zeroed) inputs.

The mask channels are built separately from the data packer and
concatenated along the channel dim after packing, keeping the packer
clean and avoiding normalization of the mask values.
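A toy sketch of the channel layout, with one scalar standing in for each channel. `pack_with_mask_channels` is a hypothetical name; the real code packs tensors and concatenates along the channel dimension.

```python
def pack_with_mask_channels(inputs, present, names):
    """Return data channels followed by mask indicator channels.

    inputs:  name -> value (already normalized)
    present: name -> bool (False means the variable is masked)
    """
    # Masked variables are zeroed in the data channels.
    data_channels = [inputs[n] if present[n] else 0.0 for n in names]
    # Indicators are built separately and are never normalized.
    mask_channels = [1.0 if present[n] else 0.0 for n in names]
    return data_channels + mask_channels


names = ["a", "b"]
packed = pack_with_mask_channels(
    inputs={"a": 0.0, "b": 7.0},
    present={"a": True, "b": False},
    names=names,
)
```

In this example both data channels are zero, but the indicator channels differ: `a` is a physically zero input that is present, while `b` is masked. The packed length is twice the number of variables.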

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Added information about the expected loss_value and mask shapes to the _reduce_to_per_channel docstring, to make the expected input sizes clearer. Also modified the test_step.py asserts to be more readable.
…asking=True

Previously, self._names was initialized from the caller-supplied names list and
then manually reconciled after _group_variable_names_by_time_type() to drop
variables missing from the dataset. This left a window where self._names and the
three partitioned lists were out of sync, and required defensive guards in
_get_variable_metadata to handle names that had already been silently skipped.

Replace self._names with a property derived from the three discovered lists
(_time_dependent_names, _time_invariant_names, _static_derived_names), which are
the single source of truth after discovery. The original requested names are
stored in self._requested_names for use only inside
_group_variable_names_by_time_type. Move the _get_variable_metadata call to after
discovery so it iterates only over found variables, allowing removal of the
missing-variable guards there.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ames

The added unit test demonstrates that when a requested variable name is not present in a dataset, it is not included in the dataset's _names attribute.
Updated the docstring to be more descriptive, cleaned up the logic, and removed a double-space typo.
Comment on lines +1252 to +1266
def test_requested_names_and_names_differ_when_variable_masked(mock_monthly_netcdfs):
    mock_data: MockData = mock_monthly_netcdfs
    config = XarrayDataConfig(data_path=mock_data.tmpdir)
    existing_names = list(mock_data.var_names.time_dependent_names)
    names_with_missing = existing_names + ["nonexistent_var"]
    dataset = XarrayDataset(
        config,
        names_with_missing,
        IntSchedule.from_constant(2),
        allow_variable_masking=True,
    )
    assert "nonexistent_var" in dataset._requested_names
    assert "nonexistent_var" not in dataset._names


@yyexela yyexela May 15, 2026

Note: This test shows that _requested_names can contain more names than are present in the dataset; only the names actually found end up in _names.

Comment thread fme/core/dataset/xarray.py Outdated
Comment on lines -753 to -755
ds = self._open_file(0)
self._get_variable_metadata(ds)
ds.close()
Author

This call was moved to the bottom of __init__ so that _names is up to date before it runs. I think it looks cleaner this way anyway.

yyexela commented May 15, 2026

Added detailed explanations of all changes to #1160

Base automatically changed from feature/masked_variables to main May 20, 2026 21:38