77
88from __future__ import annotations
99
10+ import dataclasses
1011import warnings
1112from collections .abc import Mapping , Sequence
1213from logging import Logger
3839from ax .core .observation import ObservationData , ObservationFeatures
3940from ax .core .optimization_config import OptimizationConfig
4041from ax .core .parameter import FixedParameter , RangeParameter
41- from ax .core .search_space import SearchSpace
42+ from ax .core .search_space import SearchSpace , SearchSpaceDigest
4243from ax .exceptions .core import DataRequiredError , UnsupportedError , UserInputError
4344from ax .generation_strategy .best_model_selector import (
4445 ReductionCriterion ,
@@ -504,6 +505,95 @@ def _get_task_datasets(
504505 )
505506 return task_datasets
506507
def _expand_ssd_to_joint_space(
    self,
    search_space_digest: SearchSpaceDigest,
) -> SearchSpaceDigest:
    """Return a digest widened to include source-only range parameters.

    Every ``RangeParameter`` in ``self.joint_search_space`` whose name is not
    already listed in ``search_space_digest.feature_names`` contributes one
    extra feature name and one ``(lower, upper)`` bounds pair. Parameters
    that are already present in the digest are left untouched, so their
    existing (target) bounds are kept. Source observations falling outside
    those bounds will therefore normalize outside [0, 1]; this is
    intentional, as the GP hyperprior is calibrated for a target task in
    [0, 1]^D.

    Placement of the extra features depends on the number of task features:

    * none: the extras are appended at the end;
    * exactly one: the extras are inserted immediately before the task
      feature, and the task feature's index is shifted accordingly in
      ``task_features``, ``target_values`` and ``discrete_choices``.

    Args:
        search_space_digest: The digest to expand. It is not mutated.

    Returns:
        The same ``search_space_digest`` object when there is nothing to
        add, otherwise a copy produced with ``dataclasses.replace``.

    Raises:
        UnsupportedError: If extras must be added and the digest has more
            than one task feature, or if an index-based field
            (``ordinal_features``, ``categorical_features``,
            ``fidelity_features``, ``discrete_choices`` other than the task
            feature itself, or the keys of ``hierarchical_dependencies``)
            refers to a position at or after the task feature, since such
            indices would have to be shifted.
    """
    known_names = set(search_space_digest.feature_names)
    source_only = [
        (name, (param.lower, param.upper))
        for name, param in self.joint_search_space.parameters.items()
        if name not in known_names and isinstance(param, RangeParameter)
    ]
    if not source_only:
        # Nothing to add: hand back the digest unchanged.
        return search_space_digest
    extra_names: list[str] = [name for name, _ in source_only]
    extra_bounds: list[tuple[int | float, int | float]] = [
        limits for _, limits in source_only
    ]

    task_features = search_space_digest.task_features
    if len(task_features) > 1:
        raise UnsupportedError(
            "Multiple task features are not supported in transfer learning."
        )
    if len(task_features) == 0:
        # No task feature -- the extras simply go at the end.
        return dataclasses.replace(
            search_space_digest,
            feature_names=search_space_digest.feature_names + extra_names,
            bounds=search_space_digest.bounds + extra_bounds,
        )

    # Exactly one task feature: the extras are spliced in right before it.
    tf_idx = task_features[0]

    # Inserting before the task feature moves everything at or after
    # tf_idx. Only the task feature's own entries are re-keyed below, so
    # any other index-based reference in that range is rejected up front.
    for field_name in (
        "ordinal_features",
        "categorical_features",
        "fidelity_features",
    ):
        field_indices = getattr(search_space_digest, field_name)
        if any(i >= tf_idx for i in field_indices):
            raise UnsupportedError(
                f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
            )
    # The task feature itself is allowed to appear in discrete_choices.
    if any(
        i >= tf_idx and i not in task_features
        for i in search_space_digest.discrete_choices
    ):
        raise UnsupportedError(
            f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
        )
    dependencies = search_space_digest.hierarchical_dependencies
    if dependencies is not None and any(i >= tf_idx for i in dependencies):
        raise UnsupportedError(
            "Cannot expand SSD: hierarchical_dependencies contains "
            f"index >= {tf_idx}."
        )

    shifted_tf_idx = tf_idx + len(extra_names)

    def _rekey_task_entry(mapping):
        # Copy the mapping and move the task feature's entry (if any)
        # from its old index to its shifted index.
        rekeyed = dict(mapping)
        if tf_idx in rekeyed:
            rekeyed[shifted_tf_idx] = rekeyed.pop(tf_idx)
        return rekeyed

    current_names = list(search_space_digest.feature_names)
    current_bounds = list(search_space_digest.bounds)
    return dataclasses.replace(
        search_space_digest,
        feature_names=(
            current_names[:tf_idx] + extra_names + current_names[tf_idx:]
        ),
        bounds=current_bounds[:tf_idx] + extra_bounds + current_bounds[tf_idx:],
        task_features=[shifted_tf_idx],
        target_values=_rekey_task_entry(search_space_digest.target_values),
        discrete_choices=_rekey_task_entry(search_space_digest.discrete_choices),
    )

596+
507597 def _fit (
508598 self ,
509599 search_space : SearchSpace ,
@@ -525,6 +615,10 @@ def _fit(
525615 experiment_data = experiment_data ,
526616 update_outcomes_and_parameters = True ,
527617 )
618+ # Expand SSD bounds to cover source-only params from the joint search
619+ # space. This ensures Normalize (and other input transforms) get bounds
620+ # for the full feature space, not just the target dims.
621+ search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
528622 if experiment_data .arm_data .empty :
529623 self .outcomes = outcomes
530624 # Temporarily set datasets to None. We will construct empty datasets
@@ -567,6 +661,7 @@ def _cross_validate(
567661 experiment_data = cv_training_data ,
568662 update_outcomes_and_parameters = False ,
569663 )
664+ search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
570665 # Add the task feature to SSD, to ensure that a multi-task model is selected.
571666 if len (search_space_digest .task_features ) > 1 :
572667 raise UnsupportedError (
@@ -612,7 +707,7 @@ def gen(
612707
613708 Once the ``GeneratorRun`` is produced, it checks for any fixed parameters
614709 that are not in the target search space and removes them. This is a hack
615- around limitations of the ``RemoveFixed`` transform. Since we construct the
710+ around limitations of the Ax ``RemoveFixed`` transform. Since we construct the
616711 transforms with the joint space, we end up adding back all fixed parameters
617712 from the joint space rather than adding only the parameters from the
618713 target search space. A proper fix would require passing in the search space
@@ -633,6 +728,17 @@ def gen(
633728 if fixed_features is None :
634729 fixed_features = ObservationFeatures (parameters = {})
635730 fixed_features .parameters .setdefault (name , target_p .value )
731+ # Fix source-only params so the optimizer doesn't search over them.
732+ # Center is a reasonable default; LearnedFeatureImputation overwrites
733+ # these with learned values when configured.
734+ joint_center = self .joint_search_space .compute_naive_center ()
735+ for name , param in self .joint_search_space .parameters .items ():
736+ if name not in search_space .parameters and isinstance (
737+ param , RangeParameter
738+ ):
739+ if fixed_features is None :
740+ fixed_features = ObservationFeatures (parameters = {})
741+ fixed_features .parameters .setdefault (name , joint_center [name ])
636742 generator_run = super ().gen (
637743 n = n ,
638744 search_space = search_space ,
@@ -719,12 +825,8 @@ def transfer_learning_generator_specs_constructor(
719825 selector in case there is model selection enabled.
720826 """
721827 input_transform_classes : list [type [InputTransform ]] = [Normalize ]
722- input_transform_options = {
723- "Normalize" : {
724- # None for bounds here ensures we do not use bounds from
725- # the search space digest.
726- "bounds" : None ,
727- }
828+ input_transform_options : dict [str , dict [str , Any ]] = {
829+ "Normalize" : {},
728830 }
729831 transforms = transforms or MBM_X_trans + [MetadataToTask ] + Y_trans
730832 transform_configs = get_derelativize_config (
0 commit comments