Skip to content

Commit e103bc7

Browse files
authored
F/multilable split readiness check (#278)
* decompose `_stratification.py` into two files * implement safe stratification of multilabel data * fix typing * minor bug fix * update readiness util * upd again * update tests * refactor readiness util a little bit * bug fix multilabel stratification * widen utility's coverage * annotate stop iteration error when all samples are oos * add more tests * fix typing * detect 0-samples classes too
1 parent 9d63e1e commit e103bc7

6 files changed

Lines changed: 611 additions & 124 deletions

File tree

src/autointent/context/data_handler/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
from ._data_handler import DataHandler
2-
from ._stratification import (
3-
SplitReadinessResult,
4-
StratifiedSplitter,
5-
check_split_readiness,
6-
split_dataset,
7-
)
2+
from ._readiness_util import SplitReadinessResult, check_split_readiness
3+
from ._stratification import StratifiedSplitter, split_dataset
84

95
__all__ = [
106
"DataHandler",
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from __future__ import annotations
2+
3+
from collections import Counter
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, NamedTuple
6+
7+
import numpy as np
8+
9+
if TYPE_CHECKING:
10+
from datasets import Dataset as HFDataset
11+
12+
from autointent import Dataset
13+
from autointent.configs import DataConfig
14+
15+
from ._safe_multilabel_stratification import _validate_multilabel_matrix
16+
from ._stratification import StratifiedSplitter
17+
18+
19+
class ClassCount(NamedTuple):
    """A (class index, sample count) pair reported by readiness checks."""

    id: int
    """Index of the class (intent) being counted."""

    n_samples: int
    """How many samples in the split carry this class (intent)."""
25+
26+
27+
@dataclass(frozen=True)
class SplitReadinessResult:
    """Outcome of a pre-flight check on whether a dataset split can be stratified.

    Attributes:
        ready: ``True`` when stratification can proceed (every class is populous enough
            and the requested split sizes are feasible).
        underpopulated_classes: ``(label, n_samples)`` entries for classes that fall
            below the required minimum; empty when ``ready`` is ``True``.
        min_samples_per_class_required: The per-class minimum the check was run against.
        reason: Human-readable explanation when not ready (e.g. missing split,
            infeasible split sizes), ``None`` otherwise.
    """

    ready: bool
    underpopulated_classes: list[ClassCount]
    min_samples_per_class_required: int
    reason: str | None
42+
43+
44+
def check_split_readiness(
    dataset: Dataset,
    split: str,
    config: DataConfig,
    allow_oos_in_train: bool | None = None,
) -> SplitReadinessResult:
    """Check whether the dataset has enough samples per class for autointent pipeline.

    Args:
        dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`).
        split: The split name to check (e.g. ``Split.TRAIN``).
        config: data config
        allow_oos_in_train: Same as in :func:`split_dataset`. If the split contains OOS samples
            and this is ``None``, this function raises ``ValueError`` (mirrors splitting behavior).
    """
    required_min = _min_samples_per_class_for_config(config=config)

    # A split that does not exist can never be stratified.
    if split not in dataset:
        return SplitReadinessResult(
            ready=False,
            underpopulated_classes=[],
            min_samples_per_class_required=required_min,
            reason=f"Dataset has no split '{split}'.",
        )

    source_split = dataset[split]
    # Reuse the splitter's own preprocessing so the check mirrors the real split.
    splitter = StratifiedSplitter(
        test_size=config.validation_size,
        label_feature=dataset.label_feature,
        random_seed=None,
    )
    inputs = splitter.get_stratify_inputs(source_split, dataset.multilabel, allow_oos_in_train)
    n_classes_expected = _expected_n_classes(dataset, inputs.dataset, splitter.label_feature)

    if inputs.multilabel:
        too_small = _find_underpopulated_multilabel(inputs.dataset, splitter.label_feature, required_min)
    else:
        too_small = _find_underpopulated_multiclass(
            inputs.dataset,
            splitter.label_feature,
            required_min,
            expected_n_classes=n_classes_expected,
        )

    is_ready = not too_small
    explanation: str | None = None

    # Even with enough samples per class, the requested train/test sizes may be
    # infeasible for a multiclass stratified split; verify that as well.
    if is_ready and not inputs.multilabel:
        feasible, infeasible_reason = _check_multiclass_split_size_feasibility(
            dataset=inputs.dataset,
            label_feature=splitter.label_feature,
            test_size=inputs.test_size,
            expected_n_classes=n_classes_expected,
        )
        if not feasible:
            is_ready = False
            explanation = infeasible_reason

    if not is_ready and explanation is None:
        details = "; ".join(f"class {label!r}: {count} (need {required_min})" for label, count in too_small)
        explanation = "Stratification requires at least {} samples per class. Underpopulated: {}.".format(
            required_min, details
        )

    return SplitReadinessResult(
        ready=is_ready,
        underpopulated_classes=too_small,
        min_samples_per_class_required=required_min,
        reason=explanation,
    )
110+
111+
112+
def _min_samples_per_class_for_config(config: DataConfig) -> int:
113+
"""Return a recommended minimum samples-per-class for a given data config."""
114+
# Base requirement for a single stratified split.
115+
# For CV, the canonical lower bound is one example per fold.
116+
base = 2 if config.scheme == "ho" else int(config.n_folds)
117+
118+
# separation_ratio triggers an extra stratified split of the effective train
119+
# pool (e.g. decision vs scoring), so we double the requirement.
120+
factor = 1 if config.separation_ratio is None else 2
121+
return base * factor
122+
123+
124+
def _find_underpopulated_multiclass(
    dataset: HFDataset, label_feature: str, min_samples_per_class: int, expected_n_classes: int
) -> list[ClassCount]:
    """Return (label, count) for each class with fewer than ``min_samples_per_class`` samples."""
    observed = Counter(dataset[label_feature])
    # Walk every expected class id so classes absent from the split are
    # reported with a zero count instead of being silently skipped.
    return [
        ClassCount(id=int(class_id), n_samples=int(observed.get(class_id, 0)))
        for class_id in range(int(expected_n_classes))
        if int(observed.get(class_id, 0)) < min_samples_per_class
    ]
138+
139+
140+
def _find_underpopulated_multilabel(
    dataset: HFDataset, label_feature: str, min_samples_per_class: int
) -> list[ClassCount]:
    """Return (label_idx, positive_count) for each label with too few positive samples."""
    indicator = np.asarray(dataset[label_feature])
    _validate_multilabel_matrix(indicator)
    # Column sums of the indicator matrix give the positives per label.
    positives_per_label = indicator.sum(axis=0).astype(int)
    shortfalls: list[ClassCount] = []
    for label_idx, n_positive in enumerate(positives_per_label):
        if n_positive < min_samples_per_class:
            shortfalls.append(ClassCount(id=int(label_idx), n_samples=int(n_positive)))
    return shortfalls
152+
153+
154+
def _check_multiclass_split_size_feasibility(
155+
dataset: HFDataset, label_feature: str, test_size: float, expected_n_classes: int
156+
) -> tuple[bool, str | None]:
157+
"""Return whether stratified train/test sizes are feasible for multiclass splits.
158+
159+
Even if each class has >=2 samples, sklearn stratified splitting can fail when
160+
the requested train/test sizes are too small to include all classes.
161+
"""
162+
labels = dataset[label_feature]
163+
n_classes = expected_n_classes
164+
n_samples = len(labels)
165+
166+
# Mirror sklearn's float test_size -> n_test calculation (ceil).
167+
n_test = int(np.ceil(float(test_size) * n_samples))
168+
n_train = n_samples - n_test
169+
170+
if n_test <= 0 or n_train <= 0:
171+
return (
172+
False,
173+
f"Requested split sizes are invalid (n_samples={n_samples}, test_size={test_size}).",
174+
)
175+
if n_test < n_classes:
176+
return (
177+
False,
178+
f"Stratified split would allocate too few test samples (n_test={n_test}) "
179+
f"for the number of classes (n_classes={n_classes}).",
180+
)
181+
if n_train < n_classes:
182+
return (
183+
False,
184+
f"Stratified split would allocate too few train samples (n_train={n_train}) "
185+
f"for the number of classes (n_classes={n_classes}).",
186+
)
187+
return True, None
188+
189+
190+
def _expected_n_classes(dataset: Dataset, prepared: HFDataset, label_feature: str) -> int:
191+
if dataset.multilabel:
192+
return len(prepared[label_feature][0])
193+
labels: list[int] = prepared[label_feature]
194+
max_seen = max(labels) if labels else -1
195+
return max(dataset.n_classes, int(max_seen) + 1)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
import numpy as np
6+
from skmultilearn.model_selection import IterativeStratification
7+
from transformers import set_seed
8+
9+
if TYPE_CHECKING:
10+
import numpy.typing as npt
11+
12+
# Multilabel targets must be a 2D (n_samples, n_labels) indicator matrix.
_MULTILABEL_NDIMS = 2
# A label with exactly one positive sample: that sample is forced into train.
_RARE_LABEL_COUNT_SINGLETON = 1
# A label with exactly two positive samples: its carriers are forced onto
# opposite sides of the split.
_RARE_LABEL_COUNT_PAIR = 2
# Threshold for the fair coin flip that orients an unassigned pair.
_COIN_FLIP_P = 0.5
16+
17+
18+
def safe_multilabel_split_indices(
    y: npt.NDArray[Any], test_size: float, random_seed: int | None
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
    """Split multilabel data with coverage guarantees for rare labels."""
    _validate_multilabel_matrix(y)
    total = int(y.shape[0])
    rng = np.random.default_rng(random_seed)

    train_side: set[int] = set()
    test_side: set[int] = set()
    per_label = y.sum(axis=0).astype(int)

    # Pin samples of rare labels first so no label loses coverage entirely.
    _force_singleton_labels(y=y, label_counts=per_label, train_idx=train_side)
    _force_pair_labels(y=y, label_counts=per_label, train_idx=train_side, test_idx=test_side, rng=rng)

    # Everything not pinned above is split by iterative stratification.
    pinned = train_side | test_side
    unassigned = np.array(sorted(set(range(total)) - pinned), dtype=int)
    _iterative_stratify_remaining(
        y=y,
        remaining=unassigned,
        test_size=test_size,
        random_seed=random_seed,
        train_idx=train_side,
        test_idx=test_side,
    )
    return _finalize_partition(n_samples=total, train_idx=train_side, test_idx=test_side)
44+
45+
46+
def _validate_multilabel_matrix(y: npt.NDArray[Any]) -> None:
    """Raise ``ValueError`` unless ``y`` is a 2D (n_samples, n_labels) matrix."""
    if y.ndim == _MULTILABEL_NDIMS:
        return
    msg = (
        "Expected multilabel data to be a 2D matrix-like structure "
        f"(n_samples, n_labels), got shape={getattr(y, 'shape', None)!r}."
    )
    raise ValueError(msg)
53+
54+
55+
def _assigned_split(sample_idx: int, train_idx: set[int], test_idx: set[int]) -> str | None:
56+
if sample_idx in train_idx:
57+
return "train"
58+
if sample_idx in test_idx:
59+
return "test"
60+
return None
61+
62+
63+
def _force_singleton_labels(y: npt.NDArray[Any], label_counts: npt.NDArray[Any], train_idx: set[int]) -> None:
    """Send the sole carrier of every singleton label into the train split."""
    for label, n_positive in enumerate(label_counts):
        if int(n_positive) != _RARE_LABEL_COUNT_SINGLETON:
            continue
        # The label's single positive sample goes to train so it stays learnable.
        only_carrier = int(np.flatnonzero(y[:, label])[0])
        train_idx.add(only_carrier)
69+
70+
71+
def _force_pair_samples(a: int, b: int, train_idx: set[int], test_idx: set[int], rng: np.random.Generator) -> None:
    """Place the two carriers of a pair label on opposite sides of the split."""
    a_side = _assigned_split(a, train_idx, test_idx)
    b_side = _assigned_split(b, train_idx, test_idx)

    if a_side is not None and b_side is None:
        # `a` is already placed: send `b` to the opposite side.
        (test_idx if a_side == "train" else train_idx).add(b)
    elif b_side is not None and a_side is None:
        # `b` is already placed: send `a` to the opposite side.
        (test_idx if b_side == "train" else train_idx).add(a)
    elif a_side is None and b_side is None:
        # Neither is placed yet: flip a fair coin for the orientation.
        to_train, to_test = (a, b) if rng.random() < _COIN_FLIP_P else (b, a)
        train_idx.add(to_train)
        test_idx.add(to_test)
    # Both already placed: nothing left to force.
88+
89+
90+
def _force_pair_labels(
    y: npt.NDArray[Any],
    label_counts: npt.NDArray[Any],
    train_idx: set[int],
    test_idx: set[int],
    rng: np.random.Generator,
) -> None:
    """For every label with exactly two positives, pin its carriers to opposite splits."""
    for label, n_positive in enumerate(label_counts):
        if int(n_positive) != _RARE_LABEL_COUNT_PAIR:
            continue
        carriers = np.flatnonzero(y[:, label]).astype(int).tolist()
        # Order the pair by each sample's total label count so the "lighter"
        # sample is passed as `a`; `sorted` keeps the ordering deterministic.
        lighter, heavier = sorted(carriers, key=lambda i: int(y[i].sum()))
        _force_pair_samples(a=lighter, b=heavier, train_idx=train_idx, test_idx=test_idx, rng=rng)
103+
104+
105+
def _iterative_stratify_remaining(
    y: npt.NDArray[Any],
    remaining: npt.NDArray[Any],
    test_size: float,
    random_seed: int | None,
    train_idx: set[int],
    test_idx: set[int],
) -> None:
    """Stratify the not-yet-forced samples and merge the result into the index sets.

    Args:
        y: Full (n_samples, n_labels) indicator matrix.
        remaining: Absolute indices of samples not pinned by the rare-label passes.
        test_size: Desired fraction of samples in the test fold.
        random_seed: Optional seed; when set, passed to ``set_seed`` before splitting.
        train_idx: Mutated in place — absorbs the stratified train indices.
        test_idx: Mutated in place — absorbs the stratified test indices.
    """
    # Nothing left to stratify after the forced assignments.
    if len(remaining) == 0:
        return
    if random_seed is not None:
        # Workaround for buggy nature of IterativeStratification from skmultilearn
        set_seed(random_seed)
    splitter = IterativeStratification(
        n_splits=2,
        order=2,
        # NOTE: IterativeStratification expects fold distribution in (test, train) order,
        # but returns indices as (train, test). This matches the library's behavior and
        # keeps backward-compatible train/test sizes with prior implementation.
        sample_distribution_per_fold=[test_size, 1.0 - test_size],
    )
    # Split positions 0..len(remaining)-1, then map them back to absolute indices.
    train_r, test_r = next(splitter.split(np.arange(len(remaining)), y[remaining]))
    train_idx |= set(remaining[train_r].tolist())
    test_idx |= set(remaining[test_r].tolist())
129+
130+
131+
def _finalize_partition(
132+
n_samples: int, train_idx: set[int], test_idx: set[int]
133+
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
134+
train_arr = np.array(sorted(train_idx), dtype=int)
135+
test_arr = np.array(sorted(test_idx), dtype=int)
136+
137+
if len(train_arr) + len(test_arr) != n_samples:
138+
msg = (
139+
"Multilabel split did not partition all samples: "
140+
f"n_samples={n_samples}, train={len(train_arr)}, test={len(test_arr)}."
141+
)
142+
raise RuntimeError(msg)
143+
if set(train_arr.tolist()) & set(test_arr.tolist()):
144+
msg = "Multilabel split produced overlapping train/test indices."
145+
raise RuntimeError(msg)
146+
return train_arr, test_arr

0 commit comments

Comments
 (0)