|
4 | 4 | for storage, including flattening CGNS trees, inferring data types, and handling |
5 | 5 | parallel processing of sample shards. |
6 | 6 | """ |
| 7 | + |
7 | 8 | import hashlib |
8 | 9 | import logging |
9 | 10 | import multiprocessing as mp |
@@ -42,7 +43,7 @@ def infer_dtype(value: Any) -> dict[str, int | str]: |
42 | 43 | dt = "int64" |
43 | 44 | elif np.issubdtype(dtype, np.str_): |
44 | 45 | dt = "string" |
45 | | - elif np.issubdtype(dtype, np.dtype('S1')): |
| 46 | + elif np.issubdtype(dtype, np.dtype("S1")): |
46 | 47 | dt = "S1" |
47 | 48 | else: # pragma: no cover |
48 | 49 | raise ValueError(f"Unrecognized scalar dtype: {dtype}") |
@@ -475,20 +476,27 @@ def preprocess_splits( |
475 | 476 |
|
476 | 477 | split_n_samples[split_name] = n_samples_total |
477 | 478 |
|
478 | | - # Determine truly constant paths (same hash across all samples) |
479 | | - constant_paths = [ |
480 | | - p |
481 | | - for p, entry in split_constant_hashes.items() |
482 | | - if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total |
483 | | - ] |
484 | | - |
485 | 479 | # Retrieve **values** only for constant paths from first sample |
486 | 480 | if gen_kwargs: |
487 | 481 | first_sample = next(generator_fn([shards_ids_list[0]])) # pragma: no cover |
488 | 482 | else: |
489 | 483 | first_sample = next(generator_fn()) |
490 | 484 | sample_dict, _, _ = build_sample_dict(first_sample) |
491 | 485 |
|
| 486 | + # Determine truly constant paths (same hash across all samples). A split |
| 487 | + # with a single sample has no cross-sample repetition to prove that a |
| 488 | + # value is constant. Keep only None-valued structural paths as constants |
| 489 | + # so sample-based backends still have typed per-sample data columns. |
| 490 | + # This makes it possible to work with datasets that contain only one sample. |
| 491 | + if n_samples_total <= 1: |
| 492 | + constant_paths = [p for p, value in sample_dict.items() if value is None] |
| 493 | + else: |
| 494 | + constant_paths = [ |
| 495 | + p |
| 496 | + for p, entry in split_constant_hashes.items() |
| 497 | + if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total |
| 498 | + ] |
| 499 | + |
492 | 500 | split_flat_cst[split_name] = {p: sample_dict[p] for p in sorted(constant_paths)} |
493 | 501 | split_var_path[split_name] = { |
494 | 502 | p |
@@ -538,7 +546,7 @@ def preprocess( |
538 | 546 |
|
539 | 547 | # --- build features --- |
540 | 548 | var_features = sorted(list(set().union(*split_var_path.values()))) |
541 | | - #if len(var_features) == 0: # pragma: no cover |
| 549 | + # if len(var_features) == 0: # pragma: no cover |
542 | 550 | # raise ValueError( |
543 | 551 | # "no variable feature found, is your dataset variable through samples?" |
544 | 552 | # ) |
|
0 commit comments