diff --git a/ax/adapter/registry.py b/ax/adapter/registry.py index 52b14d3a6ba..0619c35c41e 100644 --- a/ax/adapter/registry.py +++ b/ax/adapter/registry.py @@ -139,6 +139,7 @@ ] Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY] +TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY] # Expected `List[Type[Transform]]` for 2nd anonymous parameter to # call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`. diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py index 63568ff66e2..8383d65b120 100644 --- a/ax/adapter/torch.py +++ b/ax/adapter/torch.py @@ -774,6 +774,7 @@ def _get_fit_args( search_space: SearchSpace, experiment_data: ExperimentData, update_outcomes_and_parameters: bool, + data_parameters: list[str] | None = None, ) -> tuple[ list[SupervisedDataset], list[list[TCandidateMetadata]] | None, @@ -791,6 +792,11 @@ def _get_fit_args( update_outcomes_and_parameters: Whether to update `self.outcomes` with all outcomes found in the observations and `self.parameters` with all parameters in the search space. Typically only used in `_fit`. + data_parameters: When provided, columns to extract from + ``experiment_data``. Defaults to ``self.parameters``. Useful when + the model space is larger than the data (e.g. transfer learning + with heterogeneous search spaces where the data only contains + target columns but the model operates in a joint feature space). Returns: The datasets & metadata, extracted from the ``experiment_data``, and the @@ -818,12 +824,23 @@ def _get_fit_args( search_space_digest = extract_search_space_digest( search_space=search_space, param_names=self.parameters ) + extract_params = ( + data_parameters if data_parameters is not None else self.parameters + ) + # When data_parameters differs from self.parameters, the SSD's + # task_feature indices don't match the data columns. Pass None + # to skip MultiTaskDataset wrapping (the caller handles it). + extraction_ssd = ( + None + if data_parameters is not None and data_parameters != self.parameters + else search_space_digest + ) # Convert observations to datasets datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data( experiment_data=experiment_data, outcomes=self.outcomes, - parameters=self.parameters, - search_space_digest=search_space_digest, + parameters=extract_params, + search_space_digest=extraction_ssd, ) datasets = self._update_w_aux_exp_datasets(datasets=datasets) diff --git a/ax/adapter/transfer_learning/adapter.py b/ax/adapter/transfer_learning/adapter.py index 92585add888..83d93e6bb11 100644 --- a/ax/adapter/transfer_learning/adapter.py +++ b/ax/adapter/transfer_learning/adapter.py @@ -7,7 +7,6 @@ from __future__ import annotations -import dataclasses import warnings from collections.abc import Mapping, Sequence from logging import Logger @@ -23,7 +22,7 @@ Generators, GeneratorSetup, MBM_X_trans, - Y_trans, + TL_Y_trans, ) from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter from ax.adapter.transfer_learning.utils import get_joint_search_space @@ -39,7 +38,7 @@ from ax.core.observation import ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import FixedParameter, RangeParameter -from ax.core.search_space import SearchSpace, SearchSpaceDigest +from ax.core.search_space import SearchSpace from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.generation_strategy.best_model_selector import ( ReductionCriterion, @@ -54,6 +53,7 @@ from ax.utils.common.logger import get_logger from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import InputTransform, Normalize +from botorch.models.transforms.outcome import StratifiedStandardize from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from gpytorch.kernels.kernel import Kernel from pyre_extensions import assert_is_instance @@ -93,6 +93,7 @@ def __init__( ``experiment.auxiliary_experiments_by_purpose`` with type set to ``AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT``. """ + self._source_only_params: set[str] = set() transforms = [] if transforms is None else list(transforms) if MetadataToTask not in transforms: raise UserInputError( @@ -150,22 +151,7 @@ def __init__( target_search_space=search_space, ) - # Add source-only backfilled params as FixedParameter to the target - # search space so that the compatibility check passes and the model - # space includes these params (FixedToTunable will later convert them - # to RangeParameter using the joint space bounds). search_space = search_space.clone() # avoid mutating caller's object - target_param_names = set(search_space.parameters.keys()) - for name, param in self.joint_search_space.parameters.items(): - if name not in target_param_names and param.backfill_value is not None: - search_space.add_parameter( - FixedParameter( - name=name, - parameter_type=param.parameter_type, - value=param.backfill_value, - ) - ) - # Include backfill param names in filled_params so Phase 1 of # check_search_space_compatibility passes for target-only params. filled_params.extend(self.joint_search_space.backfill_values().keys()) @@ -196,6 +182,28 @@ def __init__( default_model_gen_options=default_model_gen_options, ) + def _set_search_space(self, search_space: SearchSpace) -> None: + """Set search space and model space for transfer learning. + + Overrides the base class to add source-only params (as RangeParameters) + to ``_model_space`` while preserving target bounds for shared params. + This ensures the SSD naturally covers the full joint feature space + without needing post-hoc expansion, and Normalize is anchored to target + bounds so target data maps to [0, 1]. + """ + self._search_space = search_space.clone() + model_space = search_space.clone() + self._source_only_params = set() + for name, param in self.joint_search_space.parameters.items(): + if name not in model_space.parameters and isinstance(param, RangeParameter): + model_space.add_parameter(param.clone()) + # Only mark as source-only if no backfill value exists. + # Backfilled params have known values (via FillMissingParameters) + # and should be included in data extraction. + if param.backfill_value is None: + self._source_only_params.add(name) + self._model_space = model_space + def _transform_data( self, experiment_data: ExperimentData, @@ -505,94 +513,18 @@ def _get_task_datasets( ) return task_datasets - def _expand_ssd_to_joint_space( + def _get_target_data_parameters( self, - search_space_digest: SearchSpaceDigest, - ) -> SearchSpaceDigest: - """Expand SSD bounds and feature_names to cover the joint search space. - - The SSD produced by ``_get_fit_args`` reflects the target search space. - When source experiments have additional parameters, the model operates - in the full joint feature space. This method appends bounds and feature - names for source-only parameters so that input transforms receive - correct full-space bounds. + all_params: list[str], + ) -> list[str]: + """Filter a joint parameter list to target-only params + task feature. + + Source-only params (those added by ``_set_search_space`` from the joint + space) are excluded because the target experiment data does not have + those columns. Uses untransformed names, which are stable across + transforms (Range params are never renamed by OneHot, IntToFloat, etc.). """ - existing_names = set(search_space_digest.feature_names) - extra_names: list[str] = [] - extra_bounds: list[tuple[int | float, int | float]] = [] - # Only collect parameters absent from the target SSD. Shared - # parameters that appear in both target and source keep the target - # bounds -- source observations outside those bounds will normalize - # outside [0, 1]. This is intentional, as the GP hyperprior is calibrated - # for a __target__ task in [0, 1]^D. - for name, param in self.joint_search_space.parameters.items(): - if name not in existing_names and isinstance(param, RangeParameter): - extra_names.append(name) - extra_bounds.append((param.lower, param.upper)) - if not extra_names: - return search_space_digest - # Insert source-only params before the task feature - task_features = search_space_digest.task_features - if len(task_features) == 1: - tf_idx = task_features[0] - names = list(search_space_digest.feature_names) - bounds = list(search_space_digest.bounds) - # Raise if index-based fields (other than the task feature - # itself) reference indices at or above tf_idx, since we would - # need to shift them when inserting extra params. - for field_name in ( - "ordinal_features", - "categorical_features", - "fidelity_features", - ): - indices = getattr(search_space_digest, field_name) - if any(i >= tf_idx for i in indices): - raise UnsupportedError( - f"Cannot expand SSD: {field_name} contains index >= {tf_idx}." - ) - if any( - i >= tf_idx and i not in task_features - for i in search_space_digest.discrete_choices - ): - raise UnsupportedError( - f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}." - ) - if search_space_digest.hierarchical_dependencies is not None and any( - i >= tf_idx for i in search_space_digest.hierarchical_dependencies - ): - raise UnsupportedError( - "Cannot expand SSD: hierarchical_dependencies contains " - f"index >= {tf_idx}." - ) - names[tf_idx:tf_idx] = extra_names - bounds[tf_idx:tf_idx] = extra_bounds - n_extra = len(extra_names) - new_task_features = [tf_idx + n_extra] - new_target_values = dict(search_space_digest.target_values) - if tf_idx in new_target_values: - new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx) - new_discrete = dict(search_space_digest.discrete_choices) - if tf_idx in new_discrete: - new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx) - return dataclasses.replace( - search_space_digest, - feature_names=names, - bounds=bounds, - task_features=new_task_features, - target_values=new_target_values, - discrete_choices=new_discrete, - ) - elif len(task_features) == 0: - # No task feature -- just append. - return dataclasses.replace( - search_space_digest, - feature_names=search_space_digest.feature_names + extra_names, - bounds=search_space_digest.bounds + extra_bounds, - ) - else: - raise UnsupportedError( - "Multiple task features are not supported in transfer learning." - ) + return [p for p in all_params if p not in self._source_only_params] def _fit( self, @@ -610,15 +542,20 @@ def _fit( if experiment_data.arm_data.empty: # Temporarily unset self.outcomes to avoid an error in _get_fit_args. self.outcomes = [] + # Pre-compute the joint param ordering (mirrors _get_fit_args logic) + # so we can derive the target-only subset for data extraction. + all_params = list(search_space.parameters.keys()) + task_name = Keys.TASK_FEATURE_NAME.value + if task_name in all_params: + all_params.remove(task_name) + all_params.append(task_name) + target_data_params = self._get_target_data_parameters(all_params) datasets, candidate_metadata, search_space_digest = self._get_fit_args( search_space=search_space, experiment_data=experiment_data, update_outcomes_and_parameters=True, + data_parameters=target_data_params, ) - # Expand SSD bounds to cover source-only params from the joint search - # space. This ensures Normalize (and other input transforms) get bounds - # for the full feature space, not just the target dims. - search_space_digest = self._expand_ssd_to_joint_space(search_space_digest) if experiment_data.arm_data.empty: self.outcomes = outcomes # Temporarily set datasets to None. We will construct empty datasets @@ -656,12 +593,13 @@ def _cross_validate( ) -> list[ObservationData]: if self.parameters is None: raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate")) + target_data_params = self._get_target_data_parameters(self.parameters) datasets, _, search_space_digest = self._get_fit_args( search_space=search_space, experiment_data=cv_training_data, update_outcomes_and_parameters=False, + data_parameters=target_data_params, ) - search_space_digest = self._expand_ssd_to_joint_space(search_space_digest) # Add the task feature to SSD, to ensure that a multi-task model is selected. if len(search_space_digest.task_features) > 1: raise UnsupportedError( @@ -714,9 +652,8 @@ def gen( to ``RemoveFixed.untransform_observation_features``, which requires updating the signature of all transforms. """ - # If a fixed parameter in the target search space, is - # a range parameter in the joint search space, then we - # should set it as a fixed feature here. + # If a fixed parameter in the target search space is a range + # parameter in the joint search space, pin it as a fixed feature. search_space = search_space or self._search_space for name, target_p in search_space.parameters.items(): if ( @@ -724,13 +661,13 @@ def gen( and (p := self.joint_search_space.parameters.get(name)) is not None and isinstance(p, RangeParameter) ): - # add to fixed features if fixed_features is None: fixed_features = ObservationFeatures(parameters={}) fixed_features.parameters.setdefault(name, target_p.value) - # Fix source-only params so the optimizer doesn't search over them. - # Center is a reasonable default; LearnedFeatureImputation overwrites - # these with learned values when configured. + # Fix source-only params that ARE in the search space (e.g. injected + # as FixedParam with a backfill value) so the optimizer doesn't search + # over them. Params NOT in the search space are handled by the model + # internally (HeterogeneousMTGP natively, LFI for MultiTaskGP). joint_center = self.joint_search_space.compute_naive_center() for name, param in self.joint_search_space.parameters.items(): if name not in search_space.parameters and isinstance( @@ -739,6 +676,20 @@ def gen( if fixed_features is None: fixed_features = ObservationFeatures(parameters={}) fixed_features.parameters.setdefault(name, joint_center[name]) + # At gen time, optimize over the target search space only. + # _source_only_params covers params without backfill (stable + # untransformed names). Backfilled source-only params are identified + # by checking joint SS backfill_values — their names are also stable + # (RangeParameters are never renamed by transforms). + saved_parameters = self.parameters + backfilled_source_only = { + name + for name, param in self.joint_search_space.parameters.items() + if name not in self._experiment.search_space.parameters + and param.backfill_value is not None + } + exclude = self._source_only_params | backfilled_source_only + self.parameters = [p for p in self.parameters if p not in exclude] generator_run = super().gen( n=n, search_space=search_space, @@ -747,6 +698,7 @@ def gen( fixed_features=fixed_features, model_gen_options=model_gen_options, ) + self.parameters = saved_parameters # Remove the parameters that are not in the target experiment's search # space, and update candidate_metadata_by_arm_signature to reflect the # new arm. We use the experiment's search space rather than @@ -793,7 +745,7 @@ def transfer_learning_generator_specs_constructor( Args: model_class: The MultiTask BoTorch Model to use in the BOTL. transform: Optional list of transforms to use in the Adapter. - Defaults to MBM_X_trans + [MetadataToTask] + Y_trans. + Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans. jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian model is used. torch_device: What torch device to use (defaults to None, i.e. falls back to @@ -828,7 +780,7 @@ def transfer_learning_generator_specs_constructor( input_transform_options: dict[str, dict[str, Any]] = { "Normalize": {}, } - transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans + transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans transform_configs = get_derelativize_config( derelativize_with_raw_status_quo=derelativize_with_raw_status_quo ) @@ -846,6 +798,7 @@ def transfer_learning_generator_specs_constructor( botorch_model_class=model_class, model_options=botorch_model_kwargs or {}, input_transform_classes=input_transform_classes, + outcome_transform_classes=[StratifiedStandardize], input_transform_options=input_transform_options, mll_options=mll_kwargs, covar_module_class=covar_module_class, @@ -887,5 +840,5 @@ def transfer_learning_generator_specs_constructor( GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup( adapter_class=TransferLearningAdapter, generator_class=BoTorchGenerator, - transforms=MBM_X_trans + [MetadataToTask] + Y_trans, + transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans, ) diff --git a/ax/adapter/transfer_learning/tests/test_adapter.py b/ax/adapter/transfer_learning/tests/test_adapter.py index 30902bbad7a..9e4ec8c473a 100644 --- a/ax/adapter/transfer_learning/tests/test_adapter.py +++ b/ax/adapter/transfer_learning/tests/test_adapter.py @@ -5,113 +5,194 @@ # pyre-strict -from unittest.mock import MagicMock, PropertyMock +from __future__ import annotations -from ax.adapter.transfer_learning.adapter import TransferLearningAdapter +from unittest.mock import MagicMock, patch + +import torch +from ax.adapter.transfer_learning.adapter import TL_EXP, TransferLearningAdapter +from ax.adapter.transforms.metadata_to_task import MetadataToTask +from ax.core.arm import Arm +from ax.core.auxiliary_source import AuxiliarySource +from ax.core.experiment import Experiment from ax.core.parameter import ParameterType, RangeParameter -from ax.core.search_space import SearchSpace, SearchSpaceDigest -from ax.exceptions.core import UnsupportedError +from ax.core.search_space import SearchSpace +from ax.generators.torch.botorch_modular.generator import BoTorchGenerator +from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_experiment_with_observations +from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import assert_is_instance, none_throws -class ExpandSsdToJointSpaceTest(TestCase): - def setUp(self) -> None: - super().setUp() - self.adapter = MagicMock(spec=TransferLearningAdapter) - - def _make_joint_ss(self, params: dict[str, tuple[float, float]]) -> SearchSpace: - return SearchSpace( - parameters=[ - RangeParameter( - name=n, - lower=lo, - upper=hi, - parameter_type=ParameterType.FLOAT, - ) - for n, (lo, hi) in params.items() - ] - ) +def _make_ss(params: dict[str, tuple[float, float]]) -> SearchSpace: + return SearchSpace( + parameters=[ + RangeParameter( + name=n, + lower=lo, + upper=hi, + parameter_type=ParameterType.FLOAT, + ) + for n, (lo, hi) in params.items() + ] + ) - def test_no_extra_params_returns_unchanged(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1)}) - ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertIs(result, ssd) - def test_single_task_feature_inserts_before_task(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss( - {"x1": (0, 1), "x2": (0, 1), "x3": (-2, 5)} - ) - ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertEqual(result.feature_names, ["x1", "x2", "x3", "task"]) - self.assertEqual(result.bounds, [(0, 1), (0, 1), (-2, 5), (0, 2)]) - self.assertEqual(result.task_features, [3]) - self.assertEqual(result.target_values, {3: 0}) - - def test_zero_task_features_appends(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (-1, 3)}) - ) - ssd = SearchSpaceDigest( - feature_names=["x1"], - bounds=[(0, 1)], - ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertEqual(result.feature_names, ["x1", "x2"]) - self.assertEqual(result.bounds, [(0, 1), (-1, 3)]) +def _gen_experiment( + experiment_name: str, + num_trials: int, + search_space: SearchSpace | None = None, +) -> Experiment: + exp = get_experiment_with_observations( + observations=torch.rand(num_trials, 1).tolist(), + search_space=search_space, + ) + exp.name = experiment_name + return exp - def test_discrete_choices_on_task_feature_shifted(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)}) + +class SetSearchSpaceTest(TestCase): + """_set_search_space adds source-only params to _model_space while + preserving target bounds for shared params.""" + + def test_model_space_has_source_only_params(self) -> None: + target_ss = _make_ss({"x": (0, 1), "y": (0, 1)}) + source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)}) + target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss) + source_exp = _gen_experiment("source", num_trials=3, search_space=source_ss) + source_exp.status_quo = Arm(parameters={"x": 1.0, "y": 1.0, "z": 2.5}) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp) + ] + adapter = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=False, ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - discrete_choices={2: [0, 1, 2]}, + with self.subTest("model_space_has_z"): + self.assertIn("z", adapter._model_space.parameters) + with self.subTest("search_space_unchanged"): + self.assertNotIn("z", adapter._search_space.parameters) + with self.subTest("backfilled_not_source_only"): + self.assertNotIn("z", adapter._source_only_params) + with self.subTest("shared_params_keep_target_bounds"): + x_param = assert_is_instance( + adapter._model_space.parameters["x"], RangeParameter + ) + self.assertEqual(x_param.lower, 0.0) + self.assertEqual(x_param.upper, 1.0) + with self.subTest("source_only_without_backfill"): + source_ss2 = _make_ss({"x": (0, 5), "w": (0, 10)}) + source_exp2 = _gen_experiment( + "source2", num_trials=3, search_space=source_ss2 + ) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp2) + ] + adapter2 = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=False, + ) + self.assertIn("w", adapter2._model_space.parameters) + self.assertIsInstance(adapter2._model_space.parameters["w"], RangeParameter) + + +class GetTargetDataParametersTest(TestCase): + """_get_target_data_parameters filters joint params to target-only + task.""" + + def test_filters_source_only_params(self) -> None: + adapter = MagicMock(spec=TransferLearningAdapter) + adapter._source_only_params = {"z"} + joint_params = ["x", "y", "z", Keys.TASK_FEATURE_NAME.value] + result = TransferLearningAdapter._get_target_data_parameters( + adapter, joint_params ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertEqual(result.discrete_choices, {3: [0, 1, 2]}) - self.assertEqual(result.task_features, [3]) + self.assertEqual(result, ["x", "y", Keys.TASK_FEATURE_NAME.value]) + + def test_no_source_only_params_returns_all(self) -> None: + adapter = MagicMock(spec=TransferLearningAdapter) + adapter._source_only_params = set() + params = ["x", "y", Keys.TASK_FEATURE_NAME.value] + result = TransferLearningAdapter._get_target_data_parameters(adapter, params) + self.assertEqual(result, params) - def test_hierarchical_dependencies_at_task_idx_raises(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)}) + +class FitWithDataParametersTest(TestCase): + """After _fit, self.parameters = joint params and SSD has joint bounds, + without needing _expand_ssd_to_joint_space.""" + + @mock_botorch_optimize + def test_fit_heterogeneous_ssd_has_joint_bounds(self) -> None: + target_ss = _make_ss({"x": (0, 1), "y": (0, 1)}) + source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)}) + target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss) + source_exp = _gen_experiment("source", num_trials=5, search_space=source_ss) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp) + ] + adapter = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=False, ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - hierarchical_dependencies={2: {0: [1]}}, + adapter.outcomes = list( + none_throws(target_exp.optimization_config).objective.metric_names ) - with self.assertRaisesRegex(UnsupportedError, "hierarchical_dependencies"): - TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) + adapter.parameters = list(target_exp.search_space.parameters) - def test_multiple_task_features_raises(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)}) - ) - ssd = SearchSpaceDigest( - feature_names=["x1", "task1", "task2"], - bounds=[(0, 1), (0, 1), (0, 1)], - task_features=[1, 2], + with patch.object( + adapter.generator, "fit", wraps=none_throws(adapter.generator).fit + ) as gen_fit: + experiment_data, search_space = adapter._process_and_transform_data( + experiment=target_exp, + ) + adapter._fit(search_space=search_space, experiment_data=experiment_data) + + gen_fit.assert_called_once() + ssd = gen_fit.call_args[1]["search_space_digest"] + # SSD feature names include source-only z (from the joint model space) + self.assertIn("z", ssd.feature_names) + # SSD bounds cover the joint space + z_idx = ssd.feature_names.index("z") + self.assertEqual(ssd.bounds[z_idx], (0.0, 5.0)) + # self.parameters is the joint set + self.assertIn("z", adapter.parameters) + # Task feature is last + self.assertEqual(adapter.parameters[-1], Keys.TASK_FEATURE_NAME.value) + + @mock_botorch_optimize + def test_fit_and_gen_heterogeneous(self) -> None: + """Full fit+gen round-trip with heterogeneous search spaces. + No status_quo on source, so z is truly source-only (no backfill).""" + target_ss = _make_ss({"x": (0, 1), "y": (0, 1)}) + source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)}) + target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss) + source_exp = _gen_experiment("source", num_trials=5, search_space=source_ss) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp) + ] + adapter = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=True, ) - with self.assertRaisesRegex(UnsupportedError, "Multiple task features"): - TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) + gr = adapter.gen(n=1) + # Generated arms should only have target params + for arm in gr.arms: + self.assertIn("x", arm.parameters) + self.assertIn("y", arm.parameters) + self.assertNotIn("z", arm.parameters)