Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/plaid/infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ class DataProduction(
)


def _num_samples_sort_key(key: str) -> tuple[int, str]:
"""Return the serialization sort key for ``num_samples`` split names."""
if key.startswith("train"):
return (0, key)
if key.startswith("test"):
return (1, key)
return (2, key)


def _sort_num_samples_keys(num_samples: dict[str, int]) -> dict[str, int]:
    """Return a copy of ``num_samples`` with its keys in serialization order.

    The order is decided by ``_num_samples_sort_key``: ``train*`` names first,
    then ``test*`` names, then all remaining names.

    Args:
        num_samples: Mapping from split name to sample count.

    Returns:
        A new dict holding the same items, inserted in sorted key order. The
        input mapping is not modified.
    """
    ordered: dict[str, int] = {}
    # Dicts keep insertion order, so inserting in sorted order fixes the layout.
    for split_name in sorted(num_samples, key=_num_samples_sort_key):
        ordered[split_name] = num_samples[split_name]
    return ordered


class Infos(
BaseModel,
revalidate_instances="always",
Expand Down Expand Up @@ -205,6 +221,8 @@ def save_to_file(self, path: Union[str, Path]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)

data = self.model_dump(exclude_none=True, exclude_unset=True)
if "num_samples" in data:
data["num_samples"] = _sort_num_samples_keys(data["num_samples"])
ordered_data = {key: data[key] for key in _KEY_ORDER if key in data}

# Preserve any future fields.
Expand Down
12 changes: 2 additions & 10 deletions src/plaid/storage/common/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def load_infos_from_disk(path: Union[str, Path]) -> Infos:

def load_problem_definitions_from_disk(
path: Union[str, Path],
) -> dict[str, ProblemDefinition]:
) -> dict[str, ProblemDefinition] | None:
"""Load ProblemDefinitions from a local dataset directory.

This function reads all serialized ``ProblemDefinition`` files located in the
Expand All @@ -76,13 +76,9 @@ def load_problem_definitions_from_disk(
Root dataset directory containing the ``problem_definitions/`` folder.

Returns:
dict[str, ProblemDefinition]:
dict[str, ProblemDefinition] | None:
Mapping from problem definition filename stems to loaded ``ProblemDefinition``
objects.

Raises:
ValueError:
If the ``problem_definitions/`` directory does not exist.
"""
pb_def_dir = Path(path).absolute()
if pb_def_dir.name != "problem_definitions":
Expand All @@ -102,10 +98,6 @@ def load_problem_definitions_from_disk(
) from exc
pb_defs[p.stem] = pb_def
return pb_defs
else:
raise ValueError(
f"No problem definitions found on disk. path '{pb_def_dir}'"
) # pragma: no cover


def load_constants_from_disk(
Expand Down
17 changes: 10 additions & 7 deletions src/plaid/storage/common/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,16 @@ def push_local_problem_definitions_to_hub(

api = HfApi()

api.upload_folder(
folder_path=path / Path("problem_definitions"),
repo_id=repo_id,
repo_type="dataset",
path_in_repo="problem_definitions",
commit_message="Upload problem_definitions",
)
try:
api.upload_folder(
folder_path=path / Path("problem_definitions"),
repo_id=repo_id,
repo_type="dataset",
path_in_repo="problem_definitions",
commit_message="Upload problem_definitions",
)
except ValueError:
logger.info(f"No problem definition in folder {path}")


def push_local_metadata_to_hub(
Expand Down
2 changes: 0 additions & 2 deletions tests/storage/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,6 @@ def test_hf_datasets(

loaded_pb_defs = load_problem_definitions_from_disk(test_dir)
assert set(loaded_pb_defs) == {"pb_def"}
with pytest.raises(ValueError):
load_problem_definitions_from_disk("dummy")

datasetdict, converterdict = init_from_disk(test_dir, splits=["train"])
datasetdict, converterdict = init_from_disk(test_dir)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,40 @@ def test_infos_save_and_load_roundtrip(tmp_path):
assert reloaded.num_samples == {"train": 10}


def test_infos_save_to_file_sorts_num_samples_keys(tmp_path):
    """Check that ``save_to_file`` writes ``num_samples`` as train*, test*, others."""
    # Keys are given in a scrambled order on purpose.
    scrambled_counts = {
        "validation": 3,
        "test_b": 2,
        "train_b": 5,
        "other": 1,
        "test_a": 4,
        "train_a": 6,
        "train": 7,
        "test": 8,
    }
    expected_block = [
        "  train: 7",
        "  train_a: 6",
        "  train_b: 5",
        "  test: 8",
        "  test_a: 4",
        "  test_b: 2",
        "  other: 1",
        "  validation: 3",
    ]

    infos = Infos.model_validate(_valid_infos(num_samples=scrambled_counts))
    yaml_path = tmp_path / "infos.yaml"
    infos.save_to_file(yaml_path)

    written = yaml_path.read_text(encoding="utf-8").splitlines()

    # The num_samples entries sit between its header and the next top-level key.
    first = written.index("num_samples:") + 1
    last = written.index("storage_backend: zarr")
    assert written[first:last] == expected_block


def test_infos_from_path_rejects_directory(tmp_path):
Infos.model_validate(_valid_infos()).save_to_file(tmp_path / "infos.yaml")

Expand Down
Loading