Skip to content

Commit ae6fc19

Browse files
authored
Fix label_override when forcing_loader not given labels (ai2cm#977)
During ace inference data loading, if forcing_loader does not have labels set, label_override breaks due to the lack of a label_encoding in the forcing dataset (see [beaker log](https://beaker.org/orgs/ai2/workspaces/ace/work/01KKN10MF6Q1NJJFE10GT099X4/logs?jobId=01KKN10MS6DX9GC4VFKHQ32J0T)). This PR resolves the issue by using label_override to define the label_encoding when it is used. This avoids the awkward position of needing to set both the label at the highest level of InferenceConfig and in the forcing dataset. What this PR doesn't address is the potential misconfiguration of labels in `InferenceConfig` is set to ERA5, but the forcing dataset label was mistakenly set to a different label. We should address this in a future PR. Changes: - Create a `LabelEncoding` if the dataset label isn't set but the `label_override` is not None - Add test to `test_data_loader.py‎` - [x] Tests added
1 parent e709930 commit ae6fc19

2 files changed

Lines changed: 29 additions & 0 deletions

File tree

fme/ace/data_loading/inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ def __init__(
225225
"""
226226
if label_encoding is None and config.available_labels is not None:
227227
label_encoding = LabelEncoding(labels=sorted(list(config.available_labels)))
228+
elif label_encoding is None and label_override is not None:
229+
# When labels are overridden (e.g. from config.labels), we still need
230+
# an encoding to collate them even if the dataset has no available_labels.
231+
label_encoding = LabelEncoding(labels=sorted(list(label_override)))
228232
self._label_encoding = label_encoding
229233
self._label_override = (
230234
set(label_override) if label_override is not None else None

fme/ace/data_loading/test_data_loader.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,31 @@ def test_inference_persistence_names(tmp_path):
962962
assert not torch.all(first_item["bar"] == second_item["bar"])
963963

964964

965+
def test_inference_dataset_label_override_without_dataset_labels(tmp_path):
966+
_create_dataset_on_disk(tmp_path, n_times=14)
967+
config = InferenceDataLoaderConfig(
968+
dataset=XarrayDataConfig(data_path=tmp_path),
969+
start_indices=ExplicitIndices([0]),
970+
)
971+
window_requirements = DataRequirements(
972+
names=["foo", "bar"],
973+
n_timesteps=3,
974+
)
975+
dataset = InferenceDataset(
976+
config,
977+
total_forward_steps=3,
978+
requirements=window_requirements,
979+
label_override=["era5"],
980+
)
981+
batch = dataset[0]
982+
# assert that the labels are not None and are the correct labels set during
983+
# inference config initialization
984+
assert batch.labels is not None
985+
assert batch.labels.names == ["era5"]
986+
assert batch.labels.tensor.shape[0] == 1
987+
assert batch.labels.tensor.shape[1] == 1
988+
989+
965990
def test_zarr_engine_used_sequence():
966991
config = DataLoaderConfig(
967992
dataset=ConcatDatasetConfig(

0 commit comments

Comments
 (0)