77
88from __future__ import annotations
99
10- import dataclasses
1110import warnings
1211from collections .abc import Mapping , Sequence
1312from logging import Logger
3938from ax .core .observation import ObservationData , ObservationFeatures
4039from ax .core .optimization_config import OptimizationConfig
4140from ax .core .parameter import FixedParameter , RangeParameter
42- from ax .core .search_space import SearchSpace , SearchSpaceDigest
41+ from ax .core .search_space import SearchSpace
4342from ax .exceptions .core import DataRequiredError , UnsupportedError , UserInputError
4443from ax .generation_strategy .best_model_selector import (
4544 ReductionCriterion ,
@@ -93,6 +92,7 @@ def __init__(
9392 ``experiment.auxiliary_experiments_by_purpose`` with type set to
9493 ``AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT``.
9594 """
95+ self ._source_only_params : set [str ] = set ()
9696 transforms = [] if transforms is None else list (transforms )
9797 if MetadataToTask not in transforms :
9898 raise UserInputError (
@@ -150,22 +150,7 @@ def __init__(
150150 target_search_space = search_space ,
151151 )
152152
153- # Add source-only backfilled params as FixedParameter to the target
154- # search space so that the compatibility check passes and the model
155- # space includes these params (FixedToTunable will later convert them
156- # to RangeParameter using the joint space bounds).
157153 search_space = search_space .clone () # avoid mutating caller's object
158- target_param_names = set (search_space .parameters .keys ())
159- for name , param in self .joint_search_space .parameters .items ():
160- if name not in target_param_names and param .backfill_value is not None :
161- search_space .add_parameter (
162- FixedParameter (
163- name = name ,
164- parameter_type = param .parameter_type ,
165- value = param .backfill_value ,
166- )
167- )
168-
169154 # Include backfill param names in filled_params so Phase 1 of
170155 # check_search_space_compatibility passes for target-only params.
171156 filled_params .extend (self .joint_search_space .backfill_values ().keys ())
@@ -196,6 +181,28 @@ def __init__(
196181 default_model_gen_options = default_model_gen_options ,
197182 )
198183
184+ def _set_search_space (self , search_space : SearchSpace ) -> None :
185+ """Set search space and model space for transfer learning.
186+
187+ Overrides the base class to add source-only params (as RangeParameters)
188+ to ``_model_space`` while preserving target bounds for shared params.
189+ This ensures the SSD naturally covers the full joint feature space
190+ without needing post-hoc expansion, and Normalize is anchored to target
191+ bounds so target data maps to [0, 1].
192+ """
193+ self ._search_space = search_space .clone ()
194+ model_space = search_space .clone ()
195+ self ._source_only_params = set ()
196+ for name , param in self .joint_search_space .parameters .items ():
197+ if name not in model_space .parameters and isinstance (param , RangeParameter ):
198+ model_space .add_parameter (param .clone ())
199+ # Only mark as source-only if no backfill value exists.
200+ # Backfilled params have known values (via FillMissingParameters)
201+ # and should be included in data extraction.
202+ if param .backfill_value is None :
203+ self ._source_only_params .add (name )
204+ self ._model_space = model_space
205+
199206 def _transform_data (
200207 self ,
201208 experiment_data : ExperimentData ,
@@ -505,94 +512,18 @@ def _get_task_datasets(
505512 )
506513 return task_datasets
507514
508- def _expand_ssd_to_joint_space (
515+ def _get_target_data_parameters (
509516 self ,
510- search_space_digest : SearchSpaceDigest ,
511- ) -> SearchSpaceDigest :
512- """Expand SSD bounds and feature_names to cover the joint search space.
513-
514- The SSD produced by ``_get_fit_args`` reflects the target search space.
515- When source experiments have additional parameters, the model operates
516- in the full joint feature space. This method appends bounds and feature
517- names for source-only parameters so that input transforms receive
518- correct full-space bounds.
517+ all_params : list [str ],
518+ ) -> list [str ]:
519+ """Filter a joint parameter list to target-only params + task feature.
520+
521+ Source-only params (those added by ``_set_search_space`` from the joint
522+ space) are excluded because the target experiment data does not have
523+ those columns. Uses untransformed names, which are stable across
524+ transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
519525 """
520- existing_names = set (search_space_digest .feature_names )
521- extra_names : list [str ] = []
522- extra_bounds : list [tuple [int | float , int | float ]] = []
523- # Only collect parameters absent from the target SSD. Shared
524- # parameters that appear in both target and source keep the target
525- # bounds -- source observations outside those bounds will normalize
526- # outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
527- # for a __target__ task in [0, 1]^D.
528- for name , param in self .joint_search_space .parameters .items ():
529- if name not in existing_names and isinstance (param , RangeParameter ):
530- extra_names .append (name )
531- extra_bounds .append ((param .lower , param .upper ))
532- if not extra_names :
533- return search_space_digest
534- # Insert source-only params before the task feature
535- task_features = search_space_digest .task_features
536- if len (task_features ) == 1 :
537- tf_idx = task_features [0 ]
538- names = list (search_space_digest .feature_names )
539- bounds = list (search_space_digest .bounds )
540- # Raise if index-based fields (other than the task feature
541- # itself) reference indices at or above tf_idx, since we would
542- # need to shift them when inserting extra params.
543- for field_name in (
544- "ordinal_features" ,
545- "categorical_features" ,
546- "fidelity_features" ,
547- ):
548- indices = getattr (search_space_digest , field_name )
549- if any (i >= tf_idx for i in indices ):
550- raise UnsupportedError (
551- f"Cannot expand SSD: { field_name } contains index >= { tf_idx } ."
552- )
553- if any (
554- i >= tf_idx and i not in task_features
555- for i in search_space_digest .discrete_choices
556- ):
557- raise UnsupportedError (
558- f"Cannot expand SSD: discrete_choices contains index >= { tf_idx } ."
559- )
560- if search_space_digest .hierarchical_dependencies is not None and any (
561- i >= tf_idx for i in search_space_digest .hierarchical_dependencies
562- ):
563- raise UnsupportedError (
564- "Cannot expand SSD: hierarchical_dependencies contains "
565- f"index >= { tf_idx } ."
566- )
567- names [tf_idx :tf_idx ] = extra_names
568- bounds [tf_idx :tf_idx ] = extra_bounds
569- n_extra = len (extra_names )
570- new_task_features = [tf_idx + n_extra ]
571- new_target_values = dict (search_space_digest .target_values )
572- if tf_idx in new_target_values :
573- new_target_values [new_task_features [0 ]] = new_target_values .pop (tf_idx )
574- new_discrete = dict (search_space_digest .discrete_choices )
575- if tf_idx in new_discrete :
576- new_discrete [new_task_features [0 ]] = new_discrete .pop (tf_idx )
577- return dataclasses .replace (
578- search_space_digest ,
579- feature_names = names ,
580- bounds = bounds ,
581- task_features = new_task_features ,
582- target_values = new_target_values ,
583- discrete_choices = new_discrete ,
584- )
585- elif len (task_features ) == 0 :
586- # No task feature -- just append.
587- return dataclasses .replace (
588- search_space_digest ,
589- feature_names = search_space_digest .feature_names + extra_names ,
590- bounds = search_space_digest .bounds + extra_bounds ,
591- )
592- else :
593- raise UnsupportedError (
594- "Multiple task features are not supported in transfer learning."
595- )
526+ return [p for p in all_params if p not in self ._source_only_params ]
596527
597528 def _fit (
598529 self ,
@@ -610,15 +541,20 @@ def _fit(
610541 if experiment_data .arm_data .empty :
611542 # Temporarily unset self.outcomes to avoid an error in _get_fit_args.
612543 self .outcomes = []
544+ # Pre-compute the joint param ordering (mirrors _get_fit_args logic)
545+ # so we can derive the target-only subset for data extraction.
546+ all_params = list (search_space .parameters .keys ())
547+ task_name = Keys .TASK_FEATURE_NAME .value
548+ if task_name in all_params :
549+ all_params .remove (task_name )
550+ all_params .append (task_name )
551+ target_data_params = self ._get_target_data_parameters (all_params )
613552 datasets , candidate_metadata , search_space_digest = self ._get_fit_args (
614553 search_space = search_space ,
615554 experiment_data = experiment_data ,
616555 update_outcomes_and_parameters = True ,
556+ data_parameters = target_data_params ,
617557 )
618- # Expand SSD bounds to cover source-only params from the joint search
619- # space. This ensures Normalize (and other input transforms) get bounds
620- # for the full feature space, not just the target dims.
621- search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
622558 if experiment_data .arm_data .empty :
623559 self .outcomes = outcomes
624560 # Temporarily set datasets to None. We will construct empty datasets
@@ -656,12 +592,13 @@ def _cross_validate(
656592 ) -> list [ObservationData ]:
657593 if self .parameters is None :
658594 raise ValueError (FIT_MODEL_ERROR .format (action = "_cross_validate" ))
595+ target_data_params = self ._get_target_data_parameters (self .parameters )
659596 datasets , _ , search_space_digest = self ._get_fit_args (
660597 search_space = search_space ,
661598 experiment_data = cv_training_data ,
662599 update_outcomes_and_parameters = False ,
600+ data_parameters = target_data_params ,
663601 )
664- search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
665602 # Add the task feature to SSD, to ensure that a multi-task model is selected.
666603 if len (search_space_digest .task_features ) > 1 :
667604 raise UnsupportedError (
@@ -714,23 +651,22 @@ def gen(
714651 to ``RemoveFixed.untransform_observation_features``, which requires updating
715652 the signature of all transforms.
716653 """
717- # If a fixed parameter in the target search space, is
718- # a range parameter in the joint search space, then we
719- # should set it as a fixed feature here.
654+ # If a fixed parameter in the target search space is a range
655+ # parameter in the joint search space, pin it as a fixed feature.
720656 search_space = search_space or self ._search_space
721657 for name , target_p in search_space .parameters .items ():
722658 if (
723659 isinstance (target_p , FixedParameter )
724660 and (p := self .joint_search_space .parameters .get (name )) is not None
725661 and isinstance (p , RangeParameter )
726662 ):
727- # add to fixed features
728663 if fixed_features is None :
729664 fixed_features = ObservationFeatures (parameters = {})
730665 fixed_features .parameters .setdefault (name , target_p .value )
731- # Fix source-only params so the optimizer doesn't search over them.
732- # Center is a reasonable default; LearnedFeatureImputation overwrites
733- # these with learned values when configured.
666+ # Fix joint-space params that are NOT in the search space at the
667+ # joint-space center so the optimizer doesn't search over them;
668+ # setdefault keeps any value the caller already fixed. The model may
669+ # also handle these internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
734670 joint_center = self .joint_search_space .compute_naive_center ()
735671 for name , param in self .joint_search_space .parameters .items ():
736672 if name not in search_space .parameters and isinstance (
@@ -739,6 +675,20 @@ def gen(
739675 if fixed_features is None :
740676 fixed_features = ObservationFeatures (parameters = {})
741677 fixed_features .parameters .setdefault (name , joint_center [name ])
678+ # At gen time, optimize over the target search space only.
679+ # _source_only_params covers params without backfill (stable
680+ # untransformed names). Backfilled source-only params are identified
681+ # by checking joint SS backfill_values — their names are also stable
682+ # (RangeParameters are never renamed by transforms).
683+ saved_parameters = self .parameters
684+ backfilled_source_only = {
685+ name
686+ for name , param in self .joint_search_space .parameters .items ()
687+ if name not in self ._experiment .search_space .parameters
688+ and param .backfill_value is not None
689+ }
690+ exclude = self ._source_only_params | backfilled_source_only
691+ self .parameters = [p for p in self .parameters if p not in exclude ]
742692 generator_run = super ().gen (
743693 n = n ,
744694 search_space = search_space ,
@@ -747,6 +697,7 @@ def gen(
747697 fixed_features = fixed_features ,
748698 model_gen_options = model_gen_options ,
749699 )
700+ self .parameters = saved_parameters
750701 # Remove the parameters that are not in the target experiment's search
751702 # space, and update candidate_metadata_by_arm_signature to reflect the
752703 # new arm. We use the experiment's search space rather than
0 commit comments