From 5fbb5c0beae16caf7536faacafb4f581452081db Mon Sep 17 00:00:00 2001 From: voorhs Date: Fri, 3 Apr 2026 14:05:11 +0300 Subject: [PATCH 1/4] implement new logic --- .../context/data_handler/_data_handler.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/src/autointent/context/data_handler/_data_handler.py b/src/autointent/context/data_handler/_data_handler.py index b0a6e8ad..e5a30f67 100644 --- a/src/autointent/context/data_handler/_data_handler.py +++ b/src/autointent/context/data_handler/_data_handler.py @@ -176,6 +176,36 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s train_labels = [lab for lab in train_labels if lab is not None] yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc] + def _has_oos_samples(self, split_name: str) -> bool: + """Return True if the given split contains OOS (label is None) samples.""" + if split_name not in self.dataset: + return False + hf_split = self.dataset[split_name] + label_feature = self.dataset.label_feature + oos_samples = hf_split.filter(lambda sample: sample[label_feature] is None) # noqa: B023 + return len(oos_samples) > 0 + + def _duplicate_split_for_scoring_and_decision(self, split_name: str) -> None: + """Duplicate split into _0/_1 where _0 is in-domain only. + + Intended for hold-out mode when OOS is present but separation_ratio is not set: + - scoring uses `{split_name}_0` (no OOS) + - decision uses `{split_name}_1` (full, may include OOS) + """ + if split_name not in self.dataset: + return + hf_split = self.dataset[split_name] + label_feature = self.dataset.label_feature + + in_domain = hf_split.filter(lambda sample: sample[label_feature] is not None) # noqa: B023 + if len(in_domain) == 0: + msg = f"Split '{split_name}' contains only OOS samples; cannot prepare scoring split." + raise ValueError(msg) + + self.dataset[f"{split_name}_0"] = in_domain + self.dataset[f"{split_name}_1"] = hf_split + self.dataset.pop(split_name) + def _split_ho( self, separation_ratio: FloatFromZeroToOne | None, @@ -185,8 +215,16 @@ def _split_ho( ) -> None: has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset) - if separation_ratio is not None and Split.TRAIN in self.dataset: - self._split_train(separation_ratio) + if Split.TRAIN in self.dataset: + if separation_ratio is not None: + self._split_train(separation_ratio) + elif self._has_oos_samples(Split.TRAIN): + # When OOS exists and separation_ratio is not set, keep the same in-domain pool + # for scoring and decision, but exclude OOS from scoring split. + self._duplicate_split_for_scoring_and_decision(Split.TRAIN) + # If user provided a single validation split containing OOS, make scoring validation OOS-free. + if Split.VALIDATION in self.dataset and self._has_oos_samples(Split.VALIDATION): + self._duplicate_split_for_scoring_and_decision(Split.VALIDATION) if not has_validation_split: self._split_validation_from_train(validation_size, is_few_shot, examples_per_intent) From 88f689b87bcc937d12b5f1c7ebf79b5f4c9cda7a Mon Sep 17 00:00:00 2001 From: voorhs Date: Fri, 3 Apr 2026 14:07:17 +0300 Subject: [PATCH 2/4] add new tests --- tests/data/test_data_handler.py | 66 +++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/data/test_data_handler.py b/tests/data/test_data_handler.py index d1fbf049..2b02c98d 100644 --- a/tests/data/test_data_handler.py +++ b/tests/data/test_data_handler.py @@ -5,6 +5,7 @@ from autointent import Dataset from autointent.configs import DataConfig from autointent.context.data_handler import DataHandler +from autointent.custom_types import Split from autointent.schemas import Sample @@ -246,3 +247,68 @@ def test_few_shot_split(dataset): assert Counter(dh.dataset[data_split][dh.dataset.label_feature]) == desired_specs[data_split], ( f"Failed for {data_split}" ) + + +def _make_multiclass_mapping_with_oos(*, with_validation: bool) -> dict: + in_domain = [] + # Ensure enough samples per class so stratified splitting doesn't fail. + for i in range(50): + in_domain.append({"utterance": f"c0_{i}", "label": 0}) + for i in range(50): + in_domain.append({"utterance": f"c1_{i}", "label": 1}) + + oos = [{"utterance": f"oos_{i}"} for i in range(20)] + + mapping: dict = { + "train": [*in_domain, *oos], + "intents": [{"id": 0}, {"id": 1}], + } + + if with_validation: + mapping["validation"] = [ + {"utterance": "val_c0_0", "label": 0}, + {"utterance": "val_c0_1", "label": 0}, + {"utterance": "val_c1_0", "label": 1}, + {"utterance": "val_c1_1", "label": 1}, + {"utterance": "val_oos_0"}, + {"utterance": "val_oos_1"}, + ] + + return mapping + + +def _split_has_oos_labels(dh: DataHandler, split_name: str) -> bool: + return any(lab is None for lab in dh.dataset[split_name][dh.dataset.label_feature]) + + +def test_ho_oos_without_separation_ratio_duplicates_and_filters_scoring_splits(): + """If OOS exists and separation_ratio is None, scoring splits must be OOS-free.""" + dataset = Dataset.from_dict(_make_multiclass_mapping_with_oos(with_validation=False)) + dh = DataHandler(dataset, config=DataConfig(scheme="ho", separation_ratio=None), random_seed=42) + + assert "train_0" in dh.dataset + assert "train_1" in dh.dataset + assert "validation_0" in dh.dataset + assert "validation_1" in dh.dataset + assert Split.TRAIN not in dh.dataset + assert Split.VALIDATION not in dh.dataset + + assert _split_has_oos_labels(dh, "train_0") is False + assert _split_has_oos_labels(dh, "validation_0") is False + assert _split_has_oos_labels(dh, "train_1") is True + assert _split_has_oos_labels(dh, "validation_1") is True + + +def test_ho_oos_with_user_validation_duplicates_validation_when_needed(): + """If user provides validation with OOS, it should be duplicated and filtered for scoring.""" + dataset = Dataset.from_dict(_make_multiclass_mapping_with_oos(with_validation=True)) + dh = DataHandler(dataset, config=DataConfig(scheme="ho", separation_ratio=None), random_seed=42) + + assert "train_0" in dh.dataset + assert "train_1" in dh.dataset + assert "validation_0" in dh.dataset + assert "validation_1" in dh.dataset + assert Split.VALIDATION not in dh.dataset + + assert _split_has_oos_labels(dh, "validation_0") is False + assert _split_has_oos_labels(dh, "validation_1") is True From 12ec1d4c310882c5aeb9fa9b0a925c8d42ed0c1b Mon Sep 17 00:00:00 2001 From: voorhs Date: Fri, 3 Apr 2026 14:08:00 +0300 Subject: [PATCH 3/4] unused ruff ignore --- src/autointent/context/data_handler/_data_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/autointent/context/data_handler/_data_handler.py b/src/autointent/context/data_handler/_data_handler.py index e5a30f67..c3f607bc 100644 --- a/src/autointent/context/data_handler/_data_handler.py +++ b/src/autointent/context/data_handler/_data_handler.py @@ -182,7 +182,7 @@ def _has_oos_samples(self, split_name: str) -> bool: return False hf_split = self.dataset[split_name] label_feature = self.dataset.label_feature - oos_samples = hf_split.filter(lambda sample: sample[label_feature] is None) # noqa: B023 + oos_samples = hf_split.filter(lambda sample: sample[label_feature] is None) return len(oos_samples) > 0 def _duplicate_split_for_scoring_and_decision(self, split_name: str) -> None: @@ -197,7 +197,7 @@ def _duplicate_split_for_scoring_and_decision(self, split_name: str) -> None: hf_split = self.dataset[split_name] label_feature = self.dataset.label_feature - in_domain = hf_split.filter(lambda sample: sample[label_feature] is not None) # noqa: B023 + in_domain = hf_split.filter(lambda sample: sample[label_feature] is not None) if len(in_domain) == 0: msg = f"Split '{split_name}' contains only OOS samples; cannot prepare scoring split." raise ValueError(msg) From a9b7da58fd86accc81d9605fc775a389ba7fdc48 Mon Sep 17 00:00:00 2001 From: voorhs Date: Fri, 3 Apr 2026 14:08:30 +0300 Subject: [PATCH 4/4] fix ruff --- tests/data/test_data_handler.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/data/test_data_handler.py b/tests/data/test_data_handler.py index 2b02c98d..c441428e 100644 --- a/tests/data/test_data_handler.py +++ b/tests/data/test_data_handler.py @@ -250,12 +250,10 @@ def test_few_shot_split(dataset): def _make_multiclass_mapping_with_oos(*, with_validation: bool) -> dict: - in_domain = [] # Ensure enough samples per class so stratified splitting doesn't fail. - for i in range(50): - in_domain.append({"utterance": f"c0_{i}", "label": 0}) - for i in range(50): - in_domain.append({"utterance": f"c1_{i}", "label": 1}) + in_domain = [{"utterance": f"c0_{i}", "label": 0} for i in range(50)] + [ + {"utterance": f"c1_{i}", "label": 1} for i in range(50) + ] oos = [{"utterance": f"oos_{i}"} for i in range(20)]