Skip to content

Commit 498aaf4

Browse files
🎉 (plaid-check) add a Global shape check for dataset consistency
1 parent 29047d2 commit 498aaf4

2 files changed

Lines changed: 73 additions & 0 deletions

File tree

src/plaid/cli/plaidcheck.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -524,6 +524,10 @@ def check_dataset(
524524
target_splits = target_splits & dataset_splits
525525

526526
checksum_report = {}
527+
# Track shape of each Global feature across all checked splits/samples to
528+
# detect inconsistencies (e.g. a Global stored as a scalar in one sample
529+
# and as a vector in another).
530+
global_shape_observations: dict[str, dict[tuple, list[str]]] = {}
527531
for split in sorted(target_splits):
528532
dataset = datasetdict[split]
529533
converter = converterdict[split]
@@ -595,6 +599,19 @@ def check_dataset(
595599
issue,
596600
)
597601

602+
# Record the observed shape of this Global so we can later
603+
# detect dimension mismatches across all checked samples
604+
# (across splits).
605+
if value is not None:
606+
try:
607+
shape = tuple(np.asarray(value).shape)
608+
except Exception:
609+
shape = None
610+
if shape is not None:
611+
global_shape_observations.setdefault(
612+
global_name, {}
613+
).setdefault(shape, []).append(f"{split}[{idx}]")
614+
598615
for time in sample.get_all_time_values():
599616
local_bases = sample.get_base_names(time=time)
600617
for base in local_bases:
@@ -625,6 +642,25 @@ def check_dataset(
625642
issue,
626643
)
627644

645+
# Report Globals whose dimension/shape is not consistent across all
646+
# checked samples (across splits).
647+
for global_name, shape_to_locations in global_shape_observations.items():
648+
if len(shape_to_locations) <= 1:
649+
continue
650+
details = "; ".join(
651+
f"shape={shape} at {locations[:5]}"
652+
+ (f" (+{len(locations) - 5} more)" if len(locations) > 5 else "")
653+
for shape, locations in sorted(
654+
shape_to_locations.items(), key=lambda kv: str(kv[0])
655+
)
656+
)
657+
report.add(
658+
"error",
659+
"GLOBAL_SHAPE_MISMATCH",
660+
f"global/{global_name}",
661+
f"Global '{global_name}' has inconsistent shapes across samples: {details}",
662+
)
663+
628664
# Compare checksums from every checked sample to flag identical sample data.
629665
checksum_values = list(checksum_report.values())
630666
if len(checksum_report) != len(np.unique(checksum_values)):

tests/cli/test_plaidcheck.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -633,6 +633,43 @@ def test_check_dataset_split_and_data_warnings_and_duplicates(
633633
assert any(msg.code == "DUPLICATED_DATA" for msg in report.messages)
634634

635635

636+
def test_check_dataset_reports_global_shape_mismatch(
    tmp_path: Path, monkeypatch
) -> None:
    """Checker should flag a Global whose shape differs between samples."""
    dataset = _make_minimal_layout(tmp_path)

    # Stand-ins for the on-disk loaders; each ignores the path it is given.
    def fake_load_infos(path):  # noqa: ARG001
        return _infos({"train": 2})

    def fake_load_metadata(path):  # noqa: ARG001
        return ({"train": {}}, {"Var": {}}, {"train": {}}, None)

    monkeypatch.setattr(plaidcheck, "load_infos_from_disk", fake_load_infos)
    monkeypatch.setattr(plaidcheck, "load_metadata_from_disk", fake_load_metadata)

    # Two samples built with different global values: a 0-d array (shape ())
    # and a 1-d array of two elements (shape (2,)).
    mismatched_samples = [
        _FakeSampleForCheck(global_value=np.array(1.0)),
        _FakeSampleForCheck(global_value=np.array([1.0, 2.0])),
    ]

    def fake_init_from_disk(path):  # noqa: ARG001
        datasetdict = {"train": _FakeDataset(2)}
        converterdict = {"train": _FakeConverter(mismatched_samples)}
        return (datasetdict, converterdict)

    monkeypatch.setattr(plaidcheck, "init_from_disk", fake_init_from_disk)

    report = check_dataset(dataset, splits=["train"])

    shape_errors = [
        message
        for message in report.messages
        if message.code == "GLOBAL_SHAPE_MISMATCH"
    ]
    # Exactly one message with this code is expected, reported at error severity.
    assert len(shape_errors) == 1
    assert shape_errors[0].severity == "error"
    # NOTE(review): "G" is presumably the Global name used by
    # _FakeSampleForCheck — confirm against that helper.
    assert "G" in shape_errors[0].location
671+
672+
636673
def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> None:
637674
"""Checker should emit conversion errors when converter fails on an index."""
638675
dataset = _make_minimal_layout(tmp_path)

0 commit comments

Comments
 (0)