Skip to content

Commit ac9eab4

Browse files
authored
F/always separate oos (#280)
* implement new logic * add new tests * unused ruff ignore * fix ruff
1 parent 9c4c534 commit ac9eab4

2 files changed

Lines changed: 104 additions & 2 deletions

File tree

src/autointent/context/data_handler/_data_handler.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,36 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
176176
train_labels = [lab for lab in train_labels if lab is not None]
177177
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
178178

179+
def _has_oos_samples(self, split_name: str) -> bool:
    """Return True if the given split contains OOS (label is None) samples.

    Args:
        split_name: Name of the split to inspect.

    Returns:
        False when the split is absent from ``self.dataset`` or none of its
        labels is None; True as soon as one None label is found.
    """
    if split_name not in self.dataset:
        return False
    label_feature = self.dataset.label_feature
    # Scan the label column directly: ``any`` stops at the first OOS sample and
    # no filtered copy of the split has to be built only to be counted.
    return any(label is None for label in self.dataset[split_name][label_feature])
187+
188+
def _duplicate_split_for_scoring_and_decision(self, split_name: str) -> None:
    """Replace a split with an in-domain-only `_0` copy and a complete `_1` copy.

    Meant for hold-out mode when OOS samples are present but separation_ratio is not set:
    - `{split_name}_0` keeps only samples whose label is not None (used for scoring)
    - `{split_name}_1` is the original split, OOS included (used for decision)

    The original `split_name` entry is removed afterwards. A split that is not in
    the dataset is left alone.

    Raises:
        ValueError: If no sample in the split has a label.
    """
    if split_name not in self.dataset:
        return

    full_split = self.dataset[split_name]
    label_key = self.dataset.label_feature
    labelled_only = full_split.filter(lambda row: row[label_key] is not None)

    # Scoring needs at least one labelled sample; an all-OOS split cannot provide one.
    if not len(labelled_only):
        msg = f"Split '{split_name}' contains only OOS samples; cannot prepare scoring split."
        raise ValueError(msg)

    scoring_name, decision_name = f"{split_name}_0", f"{split_name}_1"
    self.dataset[scoring_name] = labelled_only
    self.dataset[decision_name] = full_split
    self.dataset.pop(split_name)
208+
179209
def _split_ho(
180210
self,
181211
separation_ratio: FloatFromZeroToOne | None,
@@ -185,8 +215,16 @@ def _split_ho(
185215
) -> None:
186216
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
187217

188-
if separation_ratio is not None and Split.TRAIN in self.dataset:
189-
self._split_train(separation_ratio)
218+
if Split.TRAIN in self.dataset:
219+
if separation_ratio is not None:
220+
self._split_train(separation_ratio)
221+
elif self._has_oos_samples(Split.TRAIN):
222+
# When OOS exists and separation_ratio is not set, keep the same in-domain pool
223+
# for scoring and decision, but exclude OOS from scoring split.
224+
self._duplicate_split_for_scoring_and_decision(Split.TRAIN)
225+
# If user provided a single validation split containing OOS, make scoring validation OOS-free.
226+
if Split.VALIDATION in self.dataset and self._has_oos_samples(Split.VALIDATION):
227+
self._duplicate_split_for_scoring_and_decision(Split.VALIDATION)
190228

191229
if not has_validation_split:
192230
self._split_validation_from_train(validation_size, is_few_shot, examples_per_intent)

tests/data/test_data_handler.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from autointent import Dataset
66
from autointent.configs import DataConfig
77
from autointent.context.data_handler import DataHandler
8+
from autointent.custom_types import Split
89
from autointent.schemas import Sample
910

1011

@@ -246,3 +247,66 @@ def test_few_shot_split(dataset):
246247
assert Counter(dh.dataset[data_split][dh.dataset.label_feature]) == desired_specs[data_split], (
247248
f"Failed for {data_split}"
248249
)
250+
251+
252+
def _make_multiclass_mapping_with_oos(*, with_validation: bool) -> dict:
    """Build a raw two-class dataset mapping whose splits also carry OOS utterances.

    The train split holds 50 utterances for each of the labels 0 and 1, followed by
    20 utterances without a label. When `with_validation` is true, a small
    validation split (two utterances per label plus two without a label) is added.
    """
    class_ids = (0, 1)

    # Keep classes large enough so stratified splitting doesn't fail.
    labelled = [{"utterance": f"c{label}_{i}", "label": label} for label in class_ids for i in range(50)]
    unlabelled = [{"utterance": f"oos_{i}"} for i in range(20)]

    result: dict = {
        "train": labelled + unlabelled,
        "intents": [{"id": label} for label in class_ids],
    }
    if not with_validation:
        return result

    val_labelled = [{"utterance": f"val_c{label}_{i}", "label": label} for label in class_ids for i in range(2)]
    val_unlabelled = [{"utterance": f"val_oos_{i}"} for i in range(2)]
    result["validation"] = val_labelled + val_unlabelled
    return result
276+
277+
278+
def _split_has_oos_labels(dh: DataHandler, split_name: str) -> bool:
    """Tell whether the named split of the handler's dataset has any None (OOS) label."""
    label_column = dh.dataset[split_name][dh.dataset.label_feature]
    for label in label_column:
        if label is None:
            return True
    return False
280+
281+
282+
def test_ho_oos_without_separation_ratio_duplicates_and_filters_scoring_splits():
    """Hold-out with OOS and separation_ratio=None: scoring splits must contain no OOS."""
    raw_mapping = _make_multiclass_mapping_with_oos(with_validation=False)
    dataset = Dataset.from_dict(raw_mapping)
    config = DataConfig(scheme="ho", separation_ratio=None)
    dh = DataHandler(dataset, config=config, random_seed=42)

    # Both train and validation must have been replaced by their _0/_1 copies.
    for duplicated_split in ("train_0", "train_1", "validation_0", "validation_1"):
        assert duplicated_split in dh.dataset
    for replaced_split in (Split.TRAIN, Split.VALIDATION):
        assert replaced_split not in dh.dataset

    # _0 splits feed scoring and carry no OOS; _1 splits feed decision and keep it.
    assert _split_has_oos_labels(dh, "train_0") is False
    assert _split_has_oos_labels(dh, "validation_0") is False
    assert _split_has_oos_labels(dh, "train_1") is True
    assert _split_has_oos_labels(dh, "validation_1") is True
298+
299+
300+
def test_ho_oos_with_user_validation_duplicates_validation_when_needed():
    """A user-supplied validation split with OOS is duplicated, and its scoring copy is filtered."""
    raw_mapping = _make_multiclass_mapping_with_oos(with_validation=True)
    dataset = Dataset.from_dict(raw_mapping)
    config = DataConfig(scheme="ho", separation_ratio=None)
    dh = DataHandler(dataset, config=config, random_seed=42)

    for duplicated_split in ("train_0", "train_1", "validation_0", "validation_1"):
        assert duplicated_split in dh.dataset
    assert Split.VALIDATION not in dh.dataset

    # Only the decision copy of validation keeps the OOS samples.
    assert _split_has_oos_labels(dh, "validation_0") is False
    assert _split_has_oos_labels(dh, "validation_1") is True

0 commit comments

Comments
 (0)