77
88from __future__ import annotations
99
10- import dataclasses
1110import warnings
1211from collections .abc import Mapping , Sequence
1312from logging import Logger
3938from ax .core .observation import ObservationData , ObservationFeatures
4039from ax .core .optimization_config import OptimizationConfig
4140from ax .core .parameter import FixedParameter , RangeParameter
42- from ax .core .search_space import SearchSpace , SearchSpaceDigest
41+ from ax .core .search_space import SearchSpace
4342from ax .exceptions .core import DataRequiredError , UnsupportedError , UserInputError
4443from ax .generation_strategy .best_model_selector import (
4544 ReductionCriterion ,
5655from botorch .models .transforms .input import InputTransform , Normalize
5756from botorch .utils .datasets import MultiTaskDataset , SupervisedDataset
5857from gpytorch .kernels .kernel import Kernel
59- from pyre_extensions import assert_is_instance
58+ from pyre_extensions import assert_is_instance , override
6059
6160TL_EXP : AuxiliaryExperimentPurpose = AuxiliaryExperimentPurpose .TRANSFERABLE_EXPERIMENT
6261TARGET_TASK_VALUE : int = 0
@@ -93,6 +92,7 @@ def __init__(
9392 ``experiment.auxiliary_experiments_by_purpose`` with type set to
9493 ``AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT``.
9594 """
95+ self ._source_only_params : set [str ] = set ()
9696 transforms = [] if transforms is None else list (transforms )
9797 if MetadataToTask not in transforms :
9898 raise UserInputError (
@@ -150,22 +150,7 @@ def __init__(
150150 target_search_space = search_space ,
151151 )
152152
153- # Add source-only backfilled params as FixedParameter to the target
154- # search space so that the compatibility check passes and the model
155- # space includes these params (FixedToTunable will later convert them
156- # to RangeParameter using the joint space bounds).
157153 search_space = search_space .clone () # avoid mutating caller's object
158- target_param_names = set (search_space .parameters .keys ())
159- for name , param in self .joint_search_space .parameters .items ():
160- if name not in target_param_names and param .backfill_value is not None :
161- search_space .add_parameter (
162- FixedParameter (
163- name = name ,
164- parameter_type = param .parameter_type ,
165- value = param .backfill_value ,
166- )
167- )
168-
169154 # Include backfill param names in filled_params so Phase 1 of
170155 # check_search_space_compatibility passes for target-only params.
171156 filled_params .extend (self .joint_search_space .backfill_values ().keys ())
@@ -177,7 +162,26 @@ def __init__(
177162 )
178163 self ._heterogeneous_search_space : bool = False
179164 except (UserInputError , ValueError ):
180- self ._heterogeneous_search_space : bool = True
165+ # The compatibility check rejects source params absent from
166+ # the target. If every gap in both directions is fillable,
167+ # FillMissingParameters makes the data homogeneous.
168+ target_keys = set (search_space .parameters .keys ())
169+ backfill_keys = set (self .joint_search_space .backfill_values ().keys ())
170+ filled = set (filled_params )
171+ source_gaps_ok = all (
172+ n in backfill_keys
173+ for n in self .joint_search_space .parameters
174+ if n not in target_keys
175+ )
176+ target_gaps_ok = all (
177+ s .transfer_param_config .get (n , n )
178+ in s .experiment .search_space .parameters
179+ or n in filled
180+ for s in self .auxiliary_sources
181+ for n , p in search_space .parameters .items ()
182+ if not isinstance (p , FixedParameter )
183+ )
184+ self ._heterogeneous_search_space = not (source_gaps_ok and target_gaps_ok )
181185
182186 self ._task_value : int = TARGET_TASK_VALUE
183187
@@ -196,6 +200,29 @@ def __init__(
196200 default_model_gen_options = default_model_gen_options ,
197201 )
198202
@override
def _set_search_space(self, search_space: SearchSpace) -> None:
    """Store the target search space and derive the model space from it.

    ``_search_space`` becomes a clone of ``search_space``. ``_model_space``
    starts from a second, independent clone and is extended with a clone of
    every ``RangeParameter`` in ``joint_search_space`` whose name is not
    already present. Parameters that already exist in ``search_space`` are
    never replaced, so they keep the definition that was passed in.

    ``_source_only_params`` is reset and then collects the names of the
    added parameters that have no ``backfill_value``.

    Args:
        search_space: The target search space; it is cloned, not mutated.
    """
    self._search_space = search_space.clone()
    expanded_space = search_space.clone()
    self._source_only_params = set()
    for param_name, joint_param in self.joint_search_space.parameters.items():
        # Parameters already present keep their existing definition.
        if param_name in expanded_space.parameters:
            continue
        # Only range parameters from the joint space are added.
        if not isinstance(joint_param, RangeParameter):
            continue
        expanded_space.add_parameter(joint_param.clone())
        # A parameter with a backfill value is not recorded as source-only.
        if joint_param.backfill_value is None:
            self._source_only_params.add(param_name)
    self._model_space = expanded_space
225+
199226 def _transform_data (
200227 self ,
201228 experiment_data : ExperimentData ,
@@ -505,94 +532,18 @@ def _get_task_datasets(
505532 )
506533 return task_datasets
507534
def _get_target_data_parameters(
    self,
    all_params: list[str],
) -> list[str]:
    """Drop source-only parameter names from a parameter list.

    Returns the entries of ``all_params`` that are not contained in
    ``self._source_only_params``, keeping their original order. Names are
    compared exactly as given; no renaming or transformation is applied.

    Args:
        all_params: Parameter names to filter.

    Returns:
        A new list with the source-only names removed.
    """
    excluded = self._source_only_params
    kept: list[str] = []
    for param_name in all_params:
        if param_name not in excluded:
            kept.append(param_name)
    return kept
596547
597548 def _fit (
598549 self ,
@@ -610,15 +561,20 @@ def _fit(
610561 if experiment_data .arm_data .empty :
611562 # Temporarily unset self.outcomes to avoid an error in _get_fit_args.
612563 self .outcomes = []
564+ # Pre-compute the joint param ordering (mirrors _get_fit_args logic)
565+ # so we can derive the target-only subset for data extraction.
566+ all_params = list (search_space .parameters .keys ())
567+ task_name = Keys .TASK_FEATURE_NAME .value
568+ if task_name in all_params :
569+ all_params .remove (task_name )
570+ all_params .append (task_name )
571+ target_data_params = self ._get_target_data_parameters (all_params )
613572 datasets , candidate_metadata , search_space_digest = self ._get_fit_args (
614573 search_space = search_space ,
615574 experiment_data = experiment_data ,
616575 update_outcomes_and_parameters = True ,
576+ data_parameters = target_data_params ,
617577 )
618- # Expand SSD bounds to cover source-only params from the joint search
619- # space. This ensures Normalize (and other input transforms) get bounds
620- # for the full feature space, not just the target dims.
621- search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
622578 if experiment_data .arm_data .empty :
623579 self .outcomes = outcomes
624580 # Temporarily set datasets to None. We will construct empty datasets
@@ -656,12 +612,13 @@ def _cross_validate(
656612 ) -> list [ObservationData ]:
657613 if self .parameters is None :
658614 raise ValueError (FIT_MODEL_ERROR .format (action = "_cross_validate" ))
615+ target_data_params = self ._get_target_data_parameters (self .parameters )
659616 datasets , _ , search_space_digest = self ._get_fit_args (
660617 search_space = search_space ,
661618 experiment_data = cv_training_data ,
662619 update_outcomes_and_parameters = False ,
620+ data_parameters = target_data_params ,
663621 )
664- search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
665622 # Add the task feature to SSD, to ensure that a multi-task model is selected.
666623 if len (search_space_digest .task_features ) > 1 :
667624 raise UnsupportedError (
@@ -714,23 +671,22 @@ def gen(
714671 to ``RemoveFixed.untransform_observation_features``, which requires updating
715672 the signature of all transforms.
716673 """
717- # If a fixed parameter in the target search space, is
718- # a range parameter in the joint search space, then we
719- # should set it as a fixed feature here.
674+ # If a fixed parameter in the target search space is a range
675+ # parameter in the joint search space, pin it as a fixed feature.
720676 search_space = search_space or self ._search_space
721677 for name , target_p in search_space .parameters .items ():
722678 if (
723679 isinstance (target_p , FixedParameter )
724680 and (p := self .joint_search_space .parameters .get (name )) is not None
725681 and isinstance (p , RangeParameter )
726682 ):
727- # add to fixed features
728683 if fixed_features is None :
729684 fixed_features = ObservationFeatures (parameters = {})
730685 fixed_features .parameters .setdefault (name , target_p .value )
731- # Fix source-only params so the optimizer doesn't search over them.
732- # Center is a reasonable default; LearnedFeatureImputation overwrites
733- # these with learned values when configured.
686+ # Fix source-only params that ARE in the search space (e.g. injected
687+ # as FixedParam with a backfill value) so the optimizer doesn't search
688+ # over them. Params NOT in the search space are handled by the model
689+ # internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
734690 joint_center = self .joint_search_space .compute_naive_center ()
735691 for name , param in self .joint_search_space .parameters .items ():
736692 if name not in search_space .parameters and isinstance (
@@ -739,6 +695,20 @@ def gen(
739695 if fixed_features is None :
740696 fixed_features = ObservationFeatures (parameters = {})
741697 fixed_features .parameters .setdefault (name , joint_center [name ])
698+ # At gen time, optimize over the target search space only.
699+ # _source_only_params covers params without backfill (stable
700+ # untransformed names). Backfilled source-only params are identified
701+ # by checking joint SS backfill_values — their names are also stable
702+ # (RangeParameters are never renamed by transforms).
703+ saved_parameters = self .parameters
704+ backfilled_source_only = {
705+ name
706+ for name , param in self .joint_search_space .parameters .items ()
707+ if name not in self ._experiment .search_space .parameters
708+ and param .backfill_value is not None
709+ }
710+ exclude = self ._source_only_params | backfilled_source_only
711+ self .parameters = [p for p in self .parameters if p not in exclude ]
742712 generator_run = super ().gen (
743713 n = n ,
744714 search_space = search_space ,
@@ -747,6 +717,7 @@ def gen(
747717 fixed_features = fixed_features ,
748718 model_gen_options = model_gen_options ,
749719 )
720+ self .parameters = saved_parameters
750721 # Remove the parameters that are not in the target experiment's search
751722 # space, and update candidate_metadata_by_arm_signature to reflect the
752723 # new arm. We use the experiment's search space rather than
0 commit comments