Skip to content

Commit 3a30fb8

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Decouple parameter discovery from data extraction in _get_fit_args (#5200)
Summary: **Motivation:** model_space/search_space was not properly used - parameter bounds on the search space would be set from the union of source and target, and parameters that were fixed on the target would be RangeParameters if the model_space contained a Fixed/Range change in the parameter. Adds a `data_parameters` argument to `TorchAdapter._get_fit_args` that decouples SSD construction (model params) from data column extraction (target params). This lets the TL adapter set `_model_space` to include source-only RangeParameters directly, so the SSD naturally covers the full joint feature space -- eliminating the need for the `_expand_ssd_to_joint_space` post-hoc expansion. Reviewed By: saitcakmak Differential Revision: D104702983
1 parent bd2ff97 commit 3a30fb8

3 files changed

Lines changed: 210 additions & 215 deletions

File tree

ax/adapter/torch.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ def _get_fit_args(
774774
search_space: SearchSpace,
775775
experiment_data: ExperimentData,
776776
update_outcomes_and_parameters: bool,
777+
data_parameters: list[str] | None = None,
777778
) -> tuple[
778779
list[SupervisedDataset],
779780
list[list[TCandidateMetadata]] | None,
@@ -791,6 +792,11 @@ def _get_fit_args(
791792
update_outcomes_and_parameters: Whether to update `self.outcomes` with
792793
all outcomes found in the observations and `self.parameters` with
793794
all parameters in the search space. Typically only used in `_fit`.
795+
data_parameters: When provided, columns to extract from
796+
``experiment_data``. Defaults to ``self.parameters``. Useful when
797+
the model space is larger than the data (e.g. transfer learning
798+
with heterogeneous search spaces where the data only contains
799+
target columns but the model operates in a joint feature space).
794800
795801
Returns:
796802
The datasets & metadata, extracted from the ``experiment_data``, and the
@@ -818,12 +824,23 @@ def _get_fit_args(
818824
search_space_digest = extract_search_space_digest(
819825
search_space=search_space, param_names=self.parameters
820826
)
827+
extract_params = (
828+
data_parameters if data_parameters is not None else self.parameters
829+
)
830+
# When data_parameters differs from self.parameters, the SSD's
831+
# task_feature indices don't match the data columns. Pass None
832+
# to skip MultiTaskDataset wrapping (the caller handles it).
833+
extraction_ssd = (
834+
None
835+
if data_parameters is not None and data_parameters != self.parameters
836+
else search_space_digest
837+
)
821838
# Convert observations to datasets
822839
datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data(
823840
experiment_data=experiment_data,
824841
outcomes=self.outcomes,
825-
parameters=self.parameters,
826-
search_space_digest=search_space_digest,
842+
parameters=extract_params,
843+
search_space_digest=extraction_ssd,
827844
)
828845
datasets = self._update_w_aux_exp_datasets(datasets=datasets)
829846

ax/adapter/transfer_learning/adapter.py

Lines changed: 88 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from __future__ import annotations
99

10-
import dataclasses
1110
import warnings
1211
from collections.abc import Mapping, Sequence
1312
from logging import Logger
@@ -39,7 +38,7 @@
3938
from ax.core.observation import ObservationData, ObservationFeatures
4039
from ax.core.optimization_config import OptimizationConfig
4140
from ax.core.parameter import FixedParameter, RangeParameter
42-
from ax.core.search_space import SearchSpace, SearchSpaceDigest
41+
from ax.core.search_space import SearchSpace
4342
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
4443
from ax.generation_strategy.best_model_selector import (
4544
ReductionCriterion,
@@ -56,7 +55,7 @@
5655
from botorch.models.transforms.input import InputTransform, Normalize
5756
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
5857
from gpytorch.kernels.kernel import Kernel
59-
from pyre_extensions import assert_is_instance
58+
from pyre_extensions import assert_is_instance, override
6059

6160
TL_EXP: AuxiliaryExperimentPurpose = AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT
6261
TARGET_TASK_VALUE: int = 0
@@ -93,6 +92,7 @@ def __init__(
9392
``experiment.auxiliary_experiments_by_purpose`` with type set to
9493
``AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT``.
9594
"""
95+
self._source_only_params: set[str] = set()
9696
transforms = [] if transforms is None else list(transforms)
9797
if MetadataToTask not in transforms:
9898
raise UserInputError(
@@ -150,22 +150,7 @@ def __init__(
150150
target_search_space=search_space,
151151
)
152152

153-
# Add source-only backfilled params as FixedParameter to the target
154-
# search space so that the compatibility check passes and the model
155-
# space includes these params (FixedToTunable will later convert them
156-
# to RangeParameter using the joint space bounds).
157153
search_space = search_space.clone() # avoid mutating caller's object
158-
target_param_names = set(search_space.parameters.keys())
159-
for name, param in self.joint_search_space.parameters.items():
160-
if name not in target_param_names and param.backfill_value is not None:
161-
search_space.add_parameter(
162-
FixedParameter(
163-
name=name,
164-
parameter_type=param.parameter_type,
165-
value=param.backfill_value,
166-
)
167-
)
168-
169154
# Include backfill param names in filled_params so Phase 1 of
170155
# check_search_space_compatibility passes for target-only params.
171156
filled_params.extend(self.joint_search_space.backfill_values().keys())
@@ -177,7 +162,26 @@ def __init__(
177162
)
178163
self._heterogeneous_search_space: bool = False
179164
except (UserInputError, ValueError):
180-
self._heterogeneous_search_space: bool = True
165+
# The compatibility check rejects source params absent from
166+
# the target. If every gap in both directions is fillable,
167+
# FillMissingParameters makes the data homogeneous.
168+
target_keys = set(search_space.parameters.keys())
169+
backfill_keys = set(self.joint_search_space.backfill_values().keys())
170+
filled = set(filled_params)
171+
source_gaps_ok = all(
172+
n in backfill_keys
173+
for n in self.joint_search_space.parameters
174+
if n not in target_keys
175+
)
176+
target_gaps_ok = all(
177+
s.transfer_param_config.get(n, n)
178+
in s.experiment.search_space.parameters
179+
or n in filled
180+
for s in self.auxiliary_sources
181+
for n, p in search_space.parameters.items()
182+
if not isinstance(p, FixedParameter)
183+
)
184+
self._heterogeneous_search_space = not (source_gaps_ok and target_gaps_ok)
181185

182186
self._task_value: int = TARGET_TASK_VALUE
183187

@@ -196,6 +200,29 @@ def __init__(
196200
default_model_gen_options=default_model_gen_options,
197201
)
198202

203+
@override
204+
def _set_search_space(self, search_space: SearchSpace) -> None:
205+
"""Set search space and model space for transfer learning.
206+
207+
Overrides the base class to add source-only params (as RangeParameters)
208+
to ``_model_space`` while preserving target bounds for shared params.
209+
This ensures the SSD naturally covers the full joint feature space
210+
without needing post-hoc expansion, and Normalize is anchored to target
211+
bounds so target data maps to [0, 1].
212+
"""
213+
self._search_space = search_space.clone()
214+
model_space = search_space.clone()
215+
self._source_only_params = set()
216+
for name, param in self.joint_search_space.parameters.items():
217+
if name not in model_space.parameters and isinstance(param, RangeParameter):
218+
model_space.add_parameter(param.clone())
219+
# Only mark as source-only if no backfill value exists.
220+
# Backfilled params have known values (via FillMissingParameters)
221+
# and should be included in data extraction.
222+
if param.backfill_value is None:
223+
self._source_only_params.add(name)
224+
self._model_space = model_space
225+
199226
def _transform_data(
200227
self,
201228
experiment_data: ExperimentData,
@@ -505,94 +532,18 @@ def _get_task_datasets(
505532
)
506533
return task_datasets
507534

508-
def _expand_ssd_to_joint_space(
535+
def _get_target_data_parameters(
509536
self,
510-
search_space_digest: SearchSpaceDigest,
511-
) -> SearchSpaceDigest:
512-
"""Expand SSD bounds and feature_names to cover the joint search space.
513-
514-
The SSD produced by ``_get_fit_args`` reflects the target search space.
515-
When source experiments have additional parameters, the model operates
516-
in the full joint feature space. This method appends bounds and feature
517-
names for source-only parameters so that input transforms receive
518-
correct full-space bounds.
537+
all_params: list[str],
538+
) -> list[str]:
539+
"""Filter a joint parameter list to target-only params + task feature.
540+
541+
Source-only params (those added by ``_set_search_space`` from the joint
542+
space) are excluded because the target experiment data does not have
543+
those columns. Uses untransformed names, which are stable across
544+
transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
519545
"""
520-
existing_names = set(search_space_digest.feature_names)
521-
extra_names: list[str] = []
522-
extra_bounds: list[tuple[int | float, int | float]] = []
523-
# Only collect parameters absent from the target SSD. Shared
524-
# parameters that appear in both target and source keep the target
525-
# bounds -- source observations outside those bounds will normalize
526-
# outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
527-
# for a __target__ task in [0, 1]^D.
528-
for name, param in self.joint_search_space.parameters.items():
529-
if name not in existing_names and isinstance(param, RangeParameter):
530-
extra_names.append(name)
531-
extra_bounds.append((param.lower, param.upper))
532-
if not extra_names:
533-
return search_space_digest
534-
# Insert source-only params before the task feature
535-
task_features = search_space_digest.task_features
536-
if len(task_features) == 1:
537-
tf_idx = task_features[0]
538-
names = list(search_space_digest.feature_names)
539-
bounds = list(search_space_digest.bounds)
540-
# Raise if index-based fields (other than the task feature
541-
# itself) reference indices at or above tf_idx, since we would
542-
# need to shift them when inserting extra params.
543-
for field_name in (
544-
"ordinal_features",
545-
"categorical_features",
546-
"fidelity_features",
547-
):
548-
indices = getattr(search_space_digest, field_name)
549-
if any(i >= tf_idx for i in indices):
550-
raise UnsupportedError(
551-
f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
552-
)
553-
if any(
554-
i >= tf_idx and i not in task_features
555-
for i in search_space_digest.discrete_choices
556-
):
557-
raise UnsupportedError(
558-
f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
559-
)
560-
if search_space_digest.hierarchical_dependencies is not None and any(
561-
i >= tf_idx for i in search_space_digest.hierarchical_dependencies
562-
):
563-
raise UnsupportedError(
564-
"Cannot expand SSD: hierarchical_dependencies contains "
565-
f"index >= {tf_idx}."
566-
)
567-
names[tf_idx:tf_idx] = extra_names
568-
bounds[tf_idx:tf_idx] = extra_bounds
569-
n_extra = len(extra_names)
570-
new_task_features = [tf_idx + n_extra]
571-
new_target_values = dict(search_space_digest.target_values)
572-
if tf_idx in new_target_values:
573-
new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx)
574-
new_discrete = dict(search_space_digest.discrete_choices)
575-
if tf_idx in new_discrete:
576-
new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx)
577-
return dataclasses.replace(
578-
search_space_digest,
579-
feature_names=names,
580-
bounds=bounds,
581-
task_features=new_task_features,
582-
target_values=new_target_values,
583-
discrete_choices=new_discrete,
584-
)
585-
elif len(task_features) == 0:
586-
# No task feature -- just append.
587-
return dataclasses.replace(
588-
search_space_digest,
589-
feature_names=search_space_digest.feature_names + extra_names,
590-
bounds=search_space_digest.bounds + extra_bounds,
591-
)
592-
else:
593-
raise UnsupportedError(
594-
"Multiple task features are not supported in transfer learning."
595-
)
546+
return [p for p in all_params if p not in self._source_only_params]
596547

597548
def _fit(
598549
self,
@@ -610,15 +561,20 @@ def _fit(
610561
if experiment_data.arm_data.empty:
611562
# Temporarily unset self.outcomes to avoid an error in _get_fit_args.
612563
self.outcomes = []
564+
# Pre-compute the joint param ordering (mirrors _get_fit_args logic)
565+
# so we can derive the target-only subset for data extraction.
566+
all_params = list(search_space.parameters.keys())
567+
task_name = Keys.TASK_FEATURE_NAME.value
568+
if task_name in all_params:
569+
all_params.remove(task_name)
570+
all_params.append(task_name)
571+
target_data_params = self._get_target_data_parameters(all_params)
613572
datasets, candidate_metadata, search_space_digest = self._get_fit_args(
614573
search_space=search_space,
615574
experiment_data=experiment_data,
616575
update_outcomes_and_parameters=True,
576+
data_parameters=target_data_params,
617577
)
618-
# Expand SSD bounds to cover source-only params from the joint search
619-
# space. This ensures Normalize (and other input transforms) get bounds
620-
# for the full feature space, not just the target dims.
621-
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
622578
if experiment_data.arm_data.empty:
623579
self.outcomes = outcomes
624580
# Temporarily set datasets to None. We will construct empty datasets
@@ -656,12 +612,13 @@ def _cross_validate(
656612
) -> list[ObservationData]:
657613
if self.parameters is None:
658614
raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate"))
615+
target_data_params = self._get_target_data_parameters(self.parameters)
659616
datasets, _, search_space_digest = self._get_fit_args(
660617
search_space=search_space,
661618
experiment_data=cv_training_data,
662619
update_outcomes_and_parameters=False,
620+
data_parameters=target_data_params,
663621
)
664-
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
665622
# Add the task feature to SSD, to ensure that a multi-task model is selected.
666623
if len(search_space_digest.task_features) > 1:
667624
raise UnsupportedError(
@@ -714,23 +671,22 @@ def gen(
714671
to ``RemoveFixed.untransform_observation_features``, which requires updating
715672
the signature of all transforms.
716673
"""
717-
# If a fixed parameter in the target search space, is
718-
# a range parameter in the joint search space, then we
719-
# should set it as a fixed feature here.
674+
# If a fixed parameter in the target search space is a range
675+
# parameter in the joint search space, pin it as a fixed feature.
720676
search_space = search_space or self._search_space
721677
for name, target_p in search_space.parameters.items():
722678
if (
723679
isinstance(target_p, FixedParameter)
724680
and (p := self.joint_search_space.parameters.get(name)) is not None
725681
and isinstance(p, RangeParameter)
726682
):
727-
# add to fixed features
728683
if fixed_features is None:
729684
fixed_features = ObservationFeatures(parameters={})
730685
fixed_features.parameters.setdefault(name, target_p.value)
731-
# Fix source-only params so the optimizer doesn't search over them.
732-
# Center is a reasonable default; LearnedFeatureImputation overwrites
733-
# these with learned values when configured.
686+
# Fix source-only params that ARE in the search space (e.g. injected
687+
# as FixedParam with a backfill value) so the optimizer doesn't search
688+
# over them. Params NOT in the search space are handled by the model
689+
# internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
734690
joint_center = self.joint_search_space.compute_naive_center()
735691
for name, param in self.joint_search_space.parameters.items():
736692
if name not in search_space.parameters and isinstance(
@@ -739,6 +695,20 @@ def gen(
739695
if fixed_features is None:
740696
fixed_features = ObservationFeatures(parameters={})
741697
fixed_features.parameters.setdefault(name, joint_center[name])
698+
# At gen time, optimize over the target search space only.
699+
# _source_only_params covers params without backfill (stable
700+
# untransformed names). Backfilled source-only params are identified
701+
# by checking joint SS backfill_values — their names are also stable
702+
# (RangeParameters are never renamed by transforms).
703+
saved_parameters = self.parameters
704+
backfilled_source_only = {
705+
name
706+
for name, param in self.joint_search_space.parameters.items()
707+
if name not in self._experiment.search_space.parameters
708+
and param.backfill_value is not None
709+
}
710+
exclude = self._source_only_params | backfilled_source_only
711+
self.parameters = [p for p in self.parameters if p not in exclude]
742712
generator_run = super().gen(
743713
n=n,
744714
search_space=search_space,
@@ -747,6 +717,7 @@ def gen(
747717
fixed_features=fixed_features,
748718
model_gen_options=model_gen_options,
749719
)
720+
self.parameters = saved_parameters
750721
# Remove the parameters that are not in the target experiment's search
751722
# space, and update candidate_metadata_by_arm_signature to reflect the
752723
# new arm. We use the experiment's search space rather than

0 commit comments

Comments
 (0)