Skip to content

Commit bbfbc21

Browse files
committed
support for dataset with only 1 sample
1 parent 3fa23c2 commit bbfbc21

2 files changed

Lines changed: 49 additions & 9 deletions

File tree

src/plaid/storage/common/preprocessor.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
for storage, including flattening CGNS trees, inferring data types, and handling
55
parallel processing of sample shards.
66
"""
7+
78
import hashlib
89
import logging
910
import multiprocessing as mp
@@ -42,7 +43,7 @@ def infer_dtype(value: Any) -> dict[str, int | str]:
4243
dt = "int64"
4344
elif np.issubdtype(dtype, np.str_):
4445
dt = "string"
45-
elif np.issubdtype(dtype, np.dtype('S1')):
46+
elif np.issubdtype(dtype, np.dtype("S1")):
4647
dt = "S1"
4748
else: # pragma: no cover
4849
raise ValueError(f"Unrecognized scalar dtype: {dtype}")
@@ -475,20 +476,27 @@ def preprocess_splits(
475476

476477
split_n_samples[split_name] = n_samples_total
477478

478-
# Determine truly constant paths (same hash across all samples)
479-
constant_paths = [
480-
p
481-
for p, entry in split_constant_hashes.items()
482-
if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total
483-
]
484-
485479
# Retrieve **values** only for constant paths from first sample
486480
if gen_kwargs:
487481
first_sample = next(generator_fn([shards_ids_list[0]])) # pragma: no cover
488482
else:
489483
first_sample = next(generator_fn())
490484
sample_dict, _, _ = build_sample_dict(first_sample)
491485

486+
# Determine truly constant paths (same hash across all samples). A split
487+
# with a single sample has no cross-sample repetition to prove that a
488+
# value is constant. Keep only None-valued structural paths as constants
489+
# so sample-based backends still have typed per-sample data columns.
490+
# this make possible to work with dataset with only one sample
491+
if n_samples_total <= 1:
492+
constant_paths = [p for p, value in sample_dict.items() if value is None]
493+
else:
494+
constant_paths = [
495+
p
496+
for p, entry in split_constant_hashes.items()
497+
if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total
498+
]
499+
492500
split_flat_cst[split_name] = {p: sample_dict[p] for p in sorted(constant_paths)}
493501
split_var_path[split_name] = {
494502
p
@@ -538,7 +546,7 @@ def preprocess(
538546

539547
# --- build features ---
540548
var_features = sorted(list(set().union(*split_var_path.values())))
541-
#if len(var_features) == 0: # pragma: no cover
549+
# if len(var_features) == 0: # pragma: no cover
542550
# raise ValueError(
543551
# "no variable feature found, is your dataset variable through samples?"
544552
# )

tests/storage/test_storage.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,38 @@ def assert_sample(self, sample):
201201
assert "test_field_same_size" in sample.get_field_names()
202202
assert sample.get_field("test_field_same_size").shape[0] == 17
203203

204+
@pytest.mark.parametrize("backend", ["cgns", "hf_datasets", "zarr"])
205+
def test_single_sample_dataset_roundtrip(
206+
self,
207+
backend,
208+
tmp_path,
209+
samples_with_extra_global,
210+
infos,
211+
):
212+
"""HF datasets and Zarr backends support splits with one sample."""
213+
test_dir = tmp_path / f"test_single_sample_{backend}"
214+
215+
def single_sample_constructor(sample_id):
216+
return samples_with_extra_global[sample_id]
217+
218+
save_to_disk(
219+
output_folder=test_dir,
220+
sample_constructor=single_sample_constructor,
221+
ids={"train": [0]},
222+
backend=backend,
223+
infos=infos,
224+
overwrite=True,
225+
)
226+
227+
datasetdict, converterdict = init_from_disk(test_dir)
228+
dataset = datasetdict["train"]
229+
converter = converterdict["train"]
230+
231+
assert len(dataset) == 1
232+
assert converter.num_samples == 1
233+
self.assert_sample(converter.to_plaid(dataset, 0))
234+
self.assert_sample(converter.sample_to_plaid(dataset[0]))
235+
204236
# ------------------------------------------------------------------------------
205237
# HUGGING FACE BRIDGE (with tree flattening and pyarrow tables)
206238
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)