diff --git a/src/plaid/infos.py b/src/plaid/infos.py index a5180295..159ab668 100644 --- a/src/plaid/infos.py +++ b/src/plaid/infos.py @@ -41,6 +41,22 @@ class DataProduction( ) +def _num_samples_sort_key(key: str) -> tuple[int, str]: + """Return the serialization sort key for ``num_samples`` split names.""" + if key.startswith("train"): + return (0, key) + if key.startswith("test"): + return (1, key) + return (2, key) + + +def _sort_num_samples_keys(num_samples: dict[str, int]) -> dict[str, int]: + """Sort ``num_samples`` with train* keys first, then test*, then others.""" + return { + key: num_samples[key] for key in sorted(num_samples, key=_num_samples_sort_key) + } + + class Infos( BaseModel, revalidate_instances="always", @@ -205,6 +221,8 @@ def save_to_file(self, path: Union[str, Path]) -> None: path.parent.mkdir(parents=True, exist_ok=True) data = self.model_dump(exclude_none=True, exclude_unset=True) + if "num_samples" in data: + data["num_samples"] = _sort_num_samples_keys(data["num_samples"]) ordered_data = {key: data[key] for key in _KEY_ORDER if key in data} # Preserve any future fields. diff --git a/src/plaid/storage/common/reader.py b/src/plaid/storage/common/reader.py index dbca8407..ed7f5a75 100644 --- a/src/plaid/storage/common/reader.py +++ b/src/plaid/storage/common/reader.py @@ -54,7 +54,7 @@ def load_infos_from_disk(path: Union[str, Path]) -> Infos: def load_problem_definitions_from_disk( path: Union[str, Path], -) -> dict[str, ProblemDefinition]: +) -> dict[str, ProblemDefinition] | None: """Load ProblemDefinitions from a local dataset directory. This function reads all serialized ``ProblemDefinition`` files located in the @@ -76,13 +76,9 @@ def load_problem_definitions_from_disk( Root dataset directory containing the ``problem_definitions/`` folder. Returns: - dict[str, ProblemDefinition]: + dict[str, ProblemDefinition] | None: Mapping from problem definition filename stems to loaded ``ProblemDefinition`` objects. - - Raises: - ValueError: - If the ``problem_definitions/`` directory does not exist. """ pb_def_dir = Path(path).absolute() if pb_def_dir.name != "problem_definitions": @@ -102,10 +98,6 @@ def load_problem_definitions_from_disk( ) from exc pb_defs[p.stem] = pb_def return pb_defs - else: - raise ValueError( - f"No problem definitions found on disk. path '{pb_def_dir}'" - ) # pragma: no cover def load_constants_from_disk( diff --git a/src/plaid/storage/common/writer.py b/src/plaid/storage/common/writer.py index fcc03232..85fec3c2 100644 --- a/src/plaid/storage/common/writer.py +++ b/src/plaid/storage/common/writer.py @@ -304,13 +304,16 @@ def push_local_problem_definitions_to_hub( api = HfApi() - api.upload_folder( - folder_path=path / Path("problem_definitions"), - repo_id=repo_id, - repo_type="dataset", - path_in_repo="problem_definitions", - commit_message="Upload problem_definitions", - ) + try: + api.upload_folder( + folder_path=path / Path("problem_definitions"), + repo_id=repo_id, + repo_type="dataset", + path_in_repo="problem_definitions", + commit_message="Upload problem_definitions", + ) + except ValueError: + logger.info(f"No problem definition in folder {path}") def push_local_metadata_to_hub( diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index e9727767..9bd5042b 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -289,8 +289,6 @@ def test_hf_datasets( loaded_pb_defs = load_problem_definitions_from_disk(test_dir) assert set(loaded_pb_defs) == {"pb_def"} - with pytest.raises(ValueError): - load_problem_definitions_from_disk("dummy") datasetdict, converterdict = init_from_disk(test_dir, splits=["train"]) datasetdict, converterdict = init_from_disk(test_dir) diff --git a/tests/test_info.py b/tests/test_info.py index ae2d2e6d..b2ab6b65 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -170,6 +170,40 @@ def test_infos_save_and_load_roundtrip(tmp_path): assert reloaded.num_samples == {"train": 10} +def test_infos_save_to_file_sorts_num_samples_keys(tmp_path): + model = Infos.model_validate( + _valid_infos( + num_samples={ + "validation": 3, + "test_b": 2, + "train_b": 5, + "other": 1, + "test_a": 4, + "train_a": 6, + "train": 7, + "test": 8, + } + ) + ) + + target = tmp_path / "infos.yaml" + model.save_to_file(target) + lines = target.read_text(encoding="utf-8").splitlines() + + start = lines.index("num_samples:") + 1 + end = lines.index("storage_backend: zarr") + assert lines[start:end] == [ + " train: 7", + " train_a: 6", + " train_b: 5", + " test: 8", + " test_a: 4", + " test_b: 2", + " other: 1", + " validation: 3", + ] + + def test_infos_from_path_rejects_directory(tmp_path): Infos.model_validate(_valid_infos()).save_to_file(tmp_path / "infos.yaml")