Skip to content

Commit 3d68bf2

Browse files
committed
wip
1 parent 82c490c commit 3d68bf2

3 files changed

Lines changed: 62 additions & 7 deletions

File tree

src/plaid/infos.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,22 @@ class DataProduction(
4141
)
4242

4343

44+
def _num_samples_sort_key(key: str) -> tuple[int, str]:
45+
"""Return the serialization sort key for ``num_samples`` split names."""
46+
if key.startswith("train"):
47+
return (0, key)
48+
if key.startswith("test"):
49+
return (1, key)
50+
return (2, key)
51+
52+
53+
def _sort_num_samples_keys(num_samples: dict[str, int]) -> dict[str, int]:
    """Return a copy of ``num_samples`` with its keys in serialization order.

    Keys are ordered with ``_num_samples_sort_key``: ``train*`` keys first,
    then ``test*`` keys, then all others. The input mapping is not modified;
    a new dict is built whose insertion order carries the sorted order.

    Args:
        num_samples: Mapping from split name to its sample count.

    Returns:
        A new dict with the same items, inserted in sorted key order.
    """
    return {
        key: num_samples[key] for key in sorted(num_samples, key=_num_samples_sort_key)
    }
58+
59+
4460
class Infos(
4561
BaseModel,
4662
revalidate_instances="always",
@@ -205,6 +221,8 @@ def save_to_file(self, path: Union[str, Path]) -> None:
205221
path.parent.mkdir(parents=True, exist_ok=True)
206222

207223
data = self.model_dump(exclude_none=True, exclude_unset=True)
224+
if "num_samples" in data:
225+
data["num_samples"] = _sort_num_samples_keys(data["num_samples"])
208226
ordered_data = {key: data[key] for key in _KEY_ORDER if key in data}
209227

210228
# Preserve any future fields.

src/plaid/storage/common/writer.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -304,13 +304,16 @@ def push_local_problem_definitions_to_hub(
304304

305305
api = HfApi()
306306

307-
api.upload_folder(
308-
folder_path=path / Path("problem_definitions"),
309-
repo_id=repo_id,
310-
repo_type="dataset",
311-
path_in_repo="problem_definitions",
312-
commit_message="Upload problem_definitions",
313-
)
307+
try:
308+
api.upload_folder(
309+
folder_path=path / Path("problem_definitions"),
310+
repo_id=repo_id,
311+
repo_type="dataset",
312+
path_in_repo="problem_definitions",
313+
commit_message="Upload problem_definitions",
314+
)
315+
except ValueError:
316+
logger.info(f"No problem definition in folder {path}")
314317

315318

316319
def push_local_metadata_to_hub(

tests/test_info.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,40 @@ def test_infos_save_and_load_roundtrip(tmp_path):
170170
assert reloaded.num_samples == {"train": 10}
171171

172172

173+
def test_infos_save_to_file_sorts_num_samples_keys(tmp_path):
    """Check that ``save_to_file`` writes ``num_samples`` keys in sorted order.

    The model is built with deliberately shuffled split names; the saved file
    must list ``train*`` entries first, then ``test*`` entries, then the rest.
    """
    model = Infos.model_validate(
        _valid_infos(
            num_samples={
                "validation": 3,
                "test_b": 2,
                "train_b": 5,
                "other": 1,
                "test_a": 4,
                "train_a": 6,
                "train": 7,
                "test": 8,
            }
        )
    )

    target = tmp_path / "infos.yaml"
    model.save_to_file(target)
    lines = target.read_text(encoding="utf-8").splitlines()

    # The num_samples entries are the lines between the "num_samples:" header
    # and the "storage_backend: zarr" line.
    # NOTE(review): assumes the fixture serializes ``storage_backend: zarr``
    # directly after the num_samples mapping - confirm against _valid_infos.
    start = lines.index("num_samples:") + 1
    end = lines.index("storage_backend: zarr")
    # NOTE(review): the leading whitespace inside the expected strings is
    # reproduced as shown in the reviewed text; confirm it matches the
    # indentation the serializer actually emits.
    assert lines[start:end] == [
        " train: 7",
        " train_a: 6",
        " train_b: 5",
        " test: 8",
        " test_a: 4",
        " test_b: 2",
        " other: 1",
        " validation: 3",
    ]
205+
206+
173207
def test_infos_from_path_rejects_directory(tmp_path):
174208
Infos.model_validate(_valid_infos()).save_to_file(tmp_path / "infos.yaml")
175209

0 commit comments

Comments
 (0)