Skip to content

Commit d5eaabb

Browse files
🎉 (plaid-check) add a Global shape check for dataset consistency (#448)
Co-authored-by: Fabien Casenave <fabien.casenave@safrangroup.com>
1 parent 29047d2 commit d5eaabb

4 files changed

Lines changed: 230 additions & 158 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"pyyaml>=6,<7",
3333
"pycgns>=6.3,<7",
3434
"zarr>=3.1,<4",
35-
"datasets>=2.18,<5",
35+
"datasets>=2.18,<6",
3636
"numpy>=2.0,<3",
3737
"pydantic>=2.6,<3",
3838
]

src/plaid/cli/plaidcheck.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -524,6 +524,10 @@ def check_dataset(
524524
target_splits = target_splits & dataset_splits
525525

526526
checksum_report = {}
527+
# Track shape of each Global feature across all checked splits/samples to
528+
# detect inconsistencies (e.g. a Global stored as a scalar in one sample
529+
# and as a vector in another).
530+
global_shape_observations: dict[str, dict[tuple, list[str]]] = {}
527531
for split in sorted(target_splits):
528532
dataset = datasetdict[split]
529533
converter = converterdict[split]
@@ -595,6 +599,17 @@ def check_dataset(
595599
issue,
596600
)
597601

602+
# Record the observed shape of this Global so we can later
603+
# detect dimension mismatches across all checked samples
604+
# (across splits). At this point ``_check_numeric_content``
605+
# already coerced ``value`` through ``np.asarray`` without
606+
# error, so the same call here is safe.
607+
if value is not None:
608+
shape = tuple(np.asarray(value).shape)
609+
global_shape_observations.setdefault(global_name, {}).setdefault(
610+
shape, []
611+
).append(f"{split}[{idx}]")
612+
598613
for time in sample.get_all_time_values():
599614
local_bases = sample.get_base_names(time=time)
600615
for base in local_bases:
@@ -625,6 +640,25 @@ def check_dataset(
625640
issue,
626641
)
627642

643+
# Report Globals whose dimension/shape is not consistent across all
644+
# checked samples (across splits).
645+
for global_name, shape_to_locations in global_shape_observations.items():
646+
if len(shape_to_locations) <= 1:
647+
continue
648+
details = "; ".join(
649+
f"shape={shape} at {locations[:5]}"
650+
+ (f" (+{len(locations) - 5} more)" if len(locations) > 5 else "")
651+
for shape, locations in sorted(
652+
shape_to_locations.items(), key=lambda kv: str(kv[0])
653+
)
654+
)
655+
report.add(
656+
"error",
657+
"GLOBAL_SHAPE_MISMATCH",
658+
f"global/{global_name}",
659+
f"Global '{global_name}' has inconsistent shapes across samples: {details}",
660+
)
661+
628662
# Compare checksums from every checked sample to flag identical sample data.
629663
checksum_values = list(checksum_report.values())
630664
if len(checksum_report) != len(np.unique(checksum_values)):

tests/cli/test_plaidcheck.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -633,6 +633,43 @@ def test_check_dataset_split_and_data_warnings_and_duplicates(
633633
assert any(msg.code == "DUPLICATED_DATA" for msg in report.messages)
634634

635635

636+
def test_check_dataset_reports_global_shape_mismatch(
637+
tmp_path: Path, monkeypatch
638+
) -> None:
639+
"""Globals with inconsistent shapes across samples should be reported."""
640+
dataset = _make_minimal_layout(tmp_path)
641+
642+
monkeypatch.setattr(
643+
plaidcheck,
644+
"load_infos_from_disk",
645+
lambda path: _infos({"train": 2}), # noqa: ARG005
646+
)
647+
monkeypatch.setattr(
648+
plaidcheck,
649+
"load_metadata_from_disk",
650+
lambda path: ({"train": {}}, {"Var": {}}, {"train": {}}, None), # noqa: ARG005
651+
)
652+
samples = [
653+
_FakeSampleForCheck(global_value=np.array(1.0)),
654+
_FakeSampleForCheck(global_value=np.array([1.0, 2.0])),
655+
]
656+
monkeypatch.setattr(
657+
plaidcheck,
658+
"init_from_disk",
659+
lambda path: ( # noqa: ARG005
660+
{"train": _FakeDataset(2)},
661+
{"train": _FakeConverter(samples)},
662+
),
663+
)
664+
665+
report = check_dataset(dataset, splits=["train"])
666+
667+
mismatch_msgs = [m for m in report.messages if m.code == "GLOBAL_SHAPE_MISMATCH"]
668+
assert len(mismatch_msgs) == 1
669+
assert mismatch_msgs[0].severity == "error"
670+
assert "G" in mismatch_msgs[0].location
671+
672+
636673
def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> None:
637674
"""Checker should emit conversion errors when converter fails on an index."""
638675
dataset = _make_minimal_layout(tmp_path)

0 commit comments

Comments
 (0)