Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"pyyaml>=6,<7",
"pycgns>=6.3,<7",
"zarr>=3.1,<4",
"datasets>=2.18,<5",
"datasets>=2.18,<6",
"numpy>=2.0,<3",
"pydantic>=2.6,<3",
]
Expand Down
34 changes: 34 additions & 0 deletions src/plaid/cli/plaidcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,10 @@ def check_dataset(
target_splits = target_splits & dataset_splits

checksum_report = {}
# Track shape of each Global feature across all checked splits/samples to
# detect inconsistencies (e.g. a Global stored as a scalar in one sample
# and as a vector in another).
global_shape_observations: dict[str, dict[tuple, list[str]]] = {}
for split in sorted(target_splits):
dataset = datasetdict[split]
converter = converterdict[split]
Expand Down Expand Up @@ -595,6 +599,17 @@ def check_dataset(
issue,
)

# Record the observed shape of this Global so we can later
# detect dimension mismatches across all checked samples
# (across splits). At this point ``_check_numeric_content``
# already coerced ``value`` through ``np.asarray`` without
# error, so the same call here is safe.
if value is not None:
shape = tuple(np.asarray(value).shape)
global_shape_observations.setdefault(global_name, {}).setdefault(
shape, []
).append(f"{split}[{idx}]")

for time in sample.get_all_time_values():
local_bases = sample.get_base_names(time=time)
for base in local_bases:
Expand Down Expand Up @@ -625,6 +640,25 @@ def check_dataset(
issue,
)

# Report Globals whose dimension/shape is not consistent across all
# checked samples (across splits).
for global_name, shape_to_locations in global_shape_observations.items():
if len(shape_to_locations) <= 1:
continue
details = "; ".join(
f"shape={shape} at {locations[:5]}"
+ (f" (+{len(locations) - 5} more)" if len(locations) > 5 else "")
for shape, locations in sorted(
shape_to_locations.items(), key=lambda kv: str(kv[0])
)
)
report.add(
"error",
"GLOBAL_SHAPE_MISMATCH",
f"global/{global_name}",
f"Global '{global_name}' has inconsistent shapes across samples: {details}",
)

# Compare checksums from every checked sample to flag identical sample data.
checksum_values = list(checksum_report.values())
if len(checksum_report) != len(np.unique(checksum_values)):
Expand Down
37 changes: 37 additions & 0 deletions tests/cli/test_plaidcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,43 @@ def test_check_dataset_split_and_data_warnings_and_duplicates(
assert any(msg.code == "DUPLICATED_DATA" for msg in report.messages)


def test_check_dataset_reports_global_shape_mismatch(
    tmp_path: Path, monkeypatch
) -> None:
    """A Global whose shape differs between samples yields exactly one error."""
    dataset = _make_minimal_layout(tmp_path)

    # Two samples of the same split: the Global is 0-d in the first and a
    # length-2 vector in the second, so their shapes disagree.
    scalar_sample = _FakeSampleForCheck(global_value=np.array(1.0))
    vector_sample = _FakeSampleForCheck(global_value=np.array([1.0, 2.0]))
    fake_samples = [scalar_sample, vector_sample]

    # Replace every disk-backed loader used by the checker with an in-memory
    # stub; each stub ignores the path it receives.
    disk_stubs = {
        "load_infos_from_disk": lambda path: _infos({"train": 2}),  # noqa: ARG005
        "load_metadata_from_disk": lambda path: (  # noqa: ARG005
            {"train": {}},
            {"Var": {}},
            {"train": {}},
            None,
        ),
        "init_from_disk": lambda path: (  # noqa: ARG005
            {"train": _FakeDataset(2)},
            {"train": _FakeConverter(fake_samples)},
        ),
    }
    for attr_name, stub in disk_stubs.items():
        monkeypatch.setattr(plaidcheck, attr_name, stub)

    report = check_dataset(dataset, splits=["train"])

    shape_errors = []
    for message in report.messages:
        if message.code == "GLOBAL_SHAPE_MISMATCH":
            shape_errors.append(message)

    # One message per offending Global, not one per sample.
    assert len(shape_errors) == 1
    first_error = shape_errors[0]
    assert first_error.severity == "error"
    # NOTE(review): presumably the fake sample names its Global "G" -- confirm
    # against _FakeSampleForCheck.
    assert "G" in first_error.location


def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> None:
"""Checker should emit conversion errors when converter fails on an index."""
dataset = _make_minimal_layout(tmp_path)
Expand Down
Loading
Loading