
Commit 083ee5c

Merge branch 'dev' into feat/lazy-import-heavy-dependencies

# Conflicts:
#	pyproject.toml
#	src/autointent/context/data_handler/_stratification.py

2 parents: 1c9eb89 + 1124fe4

11 files changed: 630 additions & 127 deletions

docs/optimizer_config.schema.json

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@
         "type": "null"
       }
     ],
-    "default": 0.5,
+    "default": null,
     "description": "Set to float to prevent data leak between scoring and decision nodes.",
     "title": "Separation Ratio"
   },
@@ -498,7 +498,7 @@
       "scheme": "ho",
       "n_folds": 3,
       "validation_size": 0.2,
-      "separation_ratio": 0.5,
+      "separation_ratio": null,
       "is_few_shot_train": false,
       "examples_per_intent": 8
     }

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ opensearch = [
     "opensearch-py (>=3.0.0, <4.0.0)",
 ]
 openai = [
-    "openai (>=1.59.6,<2.0.0)",
+    "openai (>=2,<3)",
 ]

 [tool.uv]

src/autointent/_optimization_config.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any
+from typing import TYPE_CHECKING, Any

 from pydantic import BaseModel, Field, PositiveInt, field_validator

@@ -14,6 +14,10 @@
     get_default_hfmodel_config,
     initialize_embedder_config,
 )
+from .utils import load_preset
+
+if TYPE_CHECKING:
+    from .custom_types import SearchSpacePreset


 class OptimizationConfig(BaseModel):
@@ -46,3 +50,7 @@ def validate_embedder_config(cls, v: Any) -> EmbedderConfig:  # noqa: ANN401
     hpo_config: HPOConfig = HPOConfig()

     seed: PositiveInt = 42
+
+    @classmethod
+    def from_preset(cls, preset: SearchSpacePreset) -> OptimizationConfig:
+        return cls.model_validate(load_preset(preset))
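
The new constructor makes presets one-liners. A minimal, hypothetical usage sketch (the top-level import path and the preset name "light" are assumptions, not shown in this diff; valid names are defined by SearchSpacePreset):

```python
from autointent import OptimizationConfig  # import path assumed for illustration

# "light" is a hypothetical preset name; SearchSpacePreset enumerates the real ones.
config = OptimizationConfig.from_preset("light")  # load_preset + pydantic validation
print(config.seed)  # fields absent from the preset keep their defaults (e.g. 42)
```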

src/autointent/configs/_optimization.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ class DataConfig(BaseModel):
     )
     """Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
     separation_ratio: FloatFromZeroToOne | None = Field(
-        0.5, description="Set to float to prevent data leak between scoring and decision nodes."
+        None, description="Set to float to prevent data leak between scoring and decision nodes."
     )
     """Set to float to prevent data leak between scoring and decision nodes."""
     is_few_shot_train: bool = Field(False, description="Whether to use few-shot training.")
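
With the default flipped to None, pipelines no longer reserve a scoring/decision separation split unless explicitly requested. A short sketch of both modes, assuming DataConfig is constructible with defaults as the diff suggests:

```python
from autointent.configs import DataConfig  # module shown in this diff

cfg = DataConfig()  # new default: no separation split between scoring and decision nodes
assert cfg.separation_ratio is None

# Opt back into the previous behavior to guard against data leakage:
cfg_leak_safe = DataConfig(separation_ratio=0.5)
```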

src/autointent/context/data_handler/__init__.py

Lines changed: 2 additions & 6 deletions
@@ -1,10 +1,6 @@
 from ._data_handler import DataHandler
-from ._stratification import (
-    SplitReadinessResult,
-    StratifiedSplitter,
-    check_split_readiness,
-    split_dataset,
-)
+from ._readiness_util import SplitReadinessResult, check_split_readiness
+from ._stratification import StratifiedSplitter, split_dataset

 __all__ = [
     "DataHandler",
src/autointent/context/data_handler/_readiness_util.py (new file)

Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+from collections import Counter
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, NamedTuple
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from datasets import Dataset as HFDataset
+
+    from autointent import Dataset
+    from autointent.configs import DataConfig
+
+from ._safe_multilabel_stratification import _validate_multilabel_matrix
+from ._stratification import StratifiedSplitter
+
+
+class ClassCount(NamedTuple):
+    id: int
+    """Class (intent) index."""
+
+    n_samples: int
+    """Number of samples from the class (intent)."""
+
+
+@dataclass(frozen=True)
+class SplitReadinessResult:
+    """Result of checking whether a dataset can be fed to the autointent pipeline.
+
+    Attributes:
+        ready: True if stratification can be performed (enough samples per class).
+        underpopulated_classes: List of (label, n_samples) for classes below the minimum.
+        min_samples_per_class_required: Minimum samples per class used for the check.
+        reason: Human-readable reason when not ready (e.g. OOS not configured).
+    """
+
+    ready: bool
+    underpopulated_classes: list[ClassCount]
+    min_samples_per_class_required: int
+    reason: str | None
+
+
+def check_split_readiness(
+    dataset: Dataset,
+    split: str,
+    config: DataConfig,
+    allow_oos_in_train: bool | None = None,
+) -> SplitReadinessResult:
+    """Check whether the dataset has enough samples per class for the autointent pipeline.
+
+    Args:
+        dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`).
+        split: The split name to check (e.g. ``Split.TRAIN``).
+        config: Data config used to derive the minimum samples per class.
+        allow_oos_in_train: Same as in :func:`split_dataset`. If the split contains OOS samples
+            and this is ``None``, this function raises ``ValueError`` (mirrors splitting behavior).
+    """
+    min_samples_per_class = _min_samples_per_class_for_config(config=config)
+    if split not in dataset:
+        return SplitReadinessResult(
+            ready=False,
+            underpopulated_classes=[],
+            min_samples_per_class_required=min_samples_per_class,
+            reason=f"Dataset has no split '{split}'.",
+        )
+    hf_split = dataset[split]
+    splitter = StratifiedSplitter(
+        test_size=config.validation_size,
+        label_feature=dataset.label_feature,
+        random_seed=None,
+    )
+    inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train)
+    expected_n_classes = _expected_n_classes(dataset, inputs.dataset, splitter.label_feature)
+
+    if inputs.multilabel:
+        underpopulated = _find_underpopulated_multilabel(inputs.dataset, splitter.label_feature, min_samples_per_class)
+    else:
+        underpopulated = _find_underpopulated_multiclass(
+            inputs.dataset,
+            splitter.label_feature,
+            min_samples_per_class,
+            expected_n_classes=expected_n_classes,
+        )
+    ready = len(underpopulated) == 0
+    reason: str | None = None
+
+    if ready and (not inputs.multilabel):
+        split_ok, split_reason = _check_multiclass_split_size_feasibility(
+            dataset=inputs.dataset,
+            label_feature=splitter.label_feature,
+            test_size=inputs.test_size,
+            expected_n_classes=expected_n_classes,
+        )
+        if not split_ok:
+            ready = False
+            reason = split_reason
+
+    if not ready and reason is None:
+        parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated]
+        reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format(
+            min_samples_per_class, "; ".join(parts)
+        )
+    return SplitReadinessResult(
+        ready=ready,
+        underpopulated_classes=underpopulated,
+        min_samples_per_class_required=min_samples_per_class,
+        reason=reason,
+    )
+
+
+def _min_samples_per_class_for_config(config: DataConfig) -> int:
+    """Return a recommended minimum samples-per-class for a given data config."""
+    # Base requirement for a single stratified split.
+    # For CV, the canonical lower bound is one example per fold.
+    base = 2 if config.scheme == "ho" else int(config.n_folds)
+
+    # separation_ratio triggers an extra stratified split of the effective train
+    # pool (e.g. decision vs scoring), so we double the requirement.
+    factor = 1 if config.separation_ratio is None else 2
+    return base * factor
+
+
+def _find_underpopulated_multiclass(
+    dataset: HFDataset, label_feature: str, min_samples_per_class: int, expected_n_classes: int
+) -> list[ClassCount]:
+    """Return (label, count) for each class with fewer than min_samples_per_class samples."""
+    labels: list[int] = dataset[label_feature]
+    counts = Counter(labels)
+
+    # Ensure "missing" classes are treated as 0-count (underpopulated)
+    result: list[ClassCount] = []
+    for label in range(int(expected_n_classes)):
+        n_samples = int(counts.get(label, 0))
+        if n_samples < min_samples_per_class:
+            result.append(ClassCount(id=int(label), n_samples=n_samples))
+    return result
+
+
+def _find_underpopulated_multilabel(
+    dataset: HFDataset, label_feature: str, min_samples_per_class: int
+) -> list[ClassCount]:
+    """Return (label_idx, positive_count) for each label with fewer than min_samples_per_class positives."""
+    y = np.asarray(dataset[label_feature])
+    _validate_multilabel_matrix(y)
+    counts = y.sum(axis=0).astype(int)
+    return [
+        ClassCount(id=int(idx), n_samples=int(n_samples))
+        for idx, n_samples in enumerate(counts)
+        if n_samples < min_samples_per_class
+    ]
+
+
+def _check_multiclass_split_size_feasibility(
+    dataset: HFDataset, label_feature: str, test_size: float, expected_n_classes: int
+) -> tuple[bool, str | None]:
+    """Return whether stratified train/test sizes are feasible for multiclass splits.
+
+    Even if each class has >= 2 samples, sklearn stratified splitting can fail when
+    the requested train/test sizes are too small to include all classes.
+    """
+    labels = dataset[label_feature]
+    n_classes = expected_n_classes
+    n_samples = len(labels)
+
+    # Mirror sklearn's float test_size -> n_test calculation (ceil).
+    n_test = int(np.ceil(float(test_size) * n_samples))
+    n_train = n_samples - n_test
+
+    if n_test <= 0 or n_train <= 0:
+        return (
+            False,
+            f"Requested split sizes are invalid (n_samples={n_samples}, test_size={test_size}).",
+        )
+    if n_test < n_classes:
+        return (
+            False,
+            f"Stratified split would allocate too few test samples (n_test={n_test}) "
+            f"for the number of classes (n_classes={n_classes}).",
+        )
+    if n_train < n_classes:
+        return (
+            False,
+            f"Stratified split would allocate too few train samples (n_train={n_train}) "
+            f"for the number of classes (n_classes={n_classes}).",
+        )
+    return True, None
+
+
+def _expected_n_classes(dataset: Dataset, prepared: HFDataset, label_feature: str) -> int:
+    if dataset.multilabel:
+        return len(prepared[label_feature][0])
+    labels: list[int] = prepared[label_feature]
+    max_seen = max(labels) if labels else -1
+    return max(dataset.n_classes, int(max_seen) + 1)
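
The minimum-samples rule composes two factors: a holdout scheme needs 2 samples per class while CV needs one per fold, and a non-None separation_ratio doubles the requirement because the train pool is stratified once more. For example, scheme="cv", n_folds=3, separation_ratio=0.5 requires at least 3 * 2 = 6 samples per intent. A usage sketch of the new check (the dataset file name is an assumption for illustration):

```python
from autointent import Dataset
from autointent.configs import DataConfig
from autointent.context.data_handler import check_split_readiness

# Hypothetical data source; any autointent Dataset with a "train" split works.
dataset = Dataset.from_json("my_intents.json")
config = DataConfig(scheme="cv", n_folds=3, separation_ratio=0.5)

result = check_split_readiness(dataset, split="train", config=config)
if not result.ready:
    print(result.reason)
    for cls in result.underpopulated_classes:
        print(f"intent {cls.id}: {cls.n_samples} < {result.min_samples_per_class_required}")
```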
