Skip to content

Commit fda67a6

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Decouple parameter discovery from data extraction in _get_fit_args (#5200)
Summary: Pull Request resolved: #5200 Adds a `data_parameters` argument to `TorchAdapter._get_fit_args` that decouples SSD construction (model params) from data column extraction (target params). This lets the TL adapter set `_model_space` to include source-only RangeParameters directly, so the SSD naturally covers the full joint feature space -- eliminating the need for the `_expand_ssd_to_joint_space` post-hoc expansion. Overrides `_set_search_space` to add source-only RangeParameters from the joint search space to `_model_space` while preserving target bounds for shared params (Normalize stays anchored to target bounds). At gen time, `self.parameters` is temporarily swapped to target-only so `extract_search_space_digest` sees only params present in the gen-time search space. Deletes `_expand_ssd_to_joint_space` (~90 lines). Differential Revision: D104702983
1 parent 2dd6cc7 commit fda67a6

3 files changed

Lines changed: 243 additions & 193 deletions

File tree

ax/adapter/torch.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ def _get_fit_args(
774774
search_space: SearchSpace,
775775
experiment_data: ExperimentData,
776776
update_outcomes_and_parameters: bool,
777+
data_parameters: list[str] | None = None,
777778
) -> tuple[
778779
list[SupervisedDataset],
779780
list[list[TCandidateMetadata]] | None,
@@ -791,6 +792,11 @@ def _get_fit_args(
791792
update_outcomes_and_parameters: Whether to update `self.outcomes` with
792793
all outcomes found in the observations and `self.parameters` with
793794
all parameters in the search space. Typically only used in `_fit`.
795+
data_parameters: When provided, columns to extract from
796+
``experiment_data``. Defaults to ``self.parameters``. Useful when
797+
the model space is larger than the data (e.g. transfer learning
798+
with heterogeneous search spaces where the data only contains
799+
target columns but the model operates in a joint feature space).
794800
795801
Returns:
796802
The datasets & metadata, extracted from the ``experiment_data``, and the
@@ -818,12 +824,23 @@ def _get_fit_args(
818824
search_space_digest = extract_search_space_digest(
819825
search_space=search_space, param_names=self.parameters
820826
)
827+
extract_params = (
828+
data_parameters if data_parameters is not None else self.parameters
829+
)
830+
# When data_parameters differs from self.parameters, the SSD's
831+
# task_feature indices don't match the data columns. Pass None
832+
# to skip MultiTaskDataset wrapping (the caller handles it).
833+
extraction_ssd = (
834+
None
835+
if data_parameters is not None and data_parameters != self.parameters
836+
else search_space_digest
837+
)
821838
# Convert observations to datasets
822839
datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data(
823840
experiment_data=experiment_data,
824841
outcomes=self.outcomes,
825-
parameters=self.parameters,
826-
search_space_digest=search_space_digest,
842+
parameters=extract_params,
843+
search_space_digest=extraction_ssd,
827844
)
828845
datasets = self._update_w_aux_exp_datasets(datasets=datasets)
829846

ax/adapter/transfer_learning/adapter.py

Lines changed: 49 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from __future__ import annotations
99

10-
import dataclasses
1110
import warnings
1211
from collections.abc import Mapping, Sequence
1312
from logging import Logger
@@ -196,6 +195,24 @@ def __init__(
196195
default_model_gen_options=default_model_gen_options,
197196
)
198197

198+
def _set_search_space(self, search_space: SearchSpace) -> None:
199+
"""Set search space and model space for transfer learning.
200+
201+
Overrides the base class to add source-only params (as RangeParameters)
202+
to ``_model_space`` while preserving target bounds for shared params.
203+
This ensures the SSD naturally covers the full joint feature space
204+
without needing post-hoc expansion, and Normalize is anchored to target
205+
bounds so target data maps to [0, 1].
206+
"""
207+
self._search_space = search_space.clone()
208+
model_space = search_space.clone()
209+
self._source_only_params: set[str] = set()
210+
for name, param in self.joint_search_space.parameters.items():
211+
if name not in model_space.parameters and isinstance(param, RangeParameter):
212+
model_space.add_parameter(param.clone())
213+
self._source_only_params.add(name)
214+
self._model_space = model_space
215+
199216
def _transform_data(
200217
self,
201218
experiment_data: ExperimentData,
@@ -505,94 +522,18 @@ def _get_task_datasets(
505522
)
506523
return task_datasets
507524

508-
def _expand_ssd_to_joint_space(
525+
def _get_target_data_parameters(
509526
self,
510-
search_space_digest: SearchSpaceDigest,
511-
) -> SearchSpaceDigest:
512-
"""Expand SSD bounds and feature_names to cover the joint search space.
513-
514-
The SSD produced by ``_get_fit_args`` reflects the target search space.
515-
When source experiments have additional parameters, the model operates
516-
in the full joint feature space. This method appends bounds and feature
517-
names for source-only parameters so that input transforms receive
518-
correct full-space bounds.
527+
all_params: list[str],
528+
) -> list[str]:
529+
"""Filter a joint parameter list to target-only params + task feature.
530+
531+
Source-only params (those added by ``_set_search_space`` from the joint
532+
space) are excluded because the target experiment data does not have
533+
those columns. Uses untransformed names, which are stable across
534+
transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
519535
"""
520-
existing_names = set(search_space_digest.feature_names)
521-
extra_names: list[str] = []
522-
extra_bounds: list[tuple[int | float, int | float]] = []
523-
# Only collect parameters absent from the target SSD. Shared
524-
# parameters that appear in both target and source keep the target
525-
# bounds -- source observations outside those bounds will normalize
526-
# outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
527-
# for a __target__ task in [0, 1]^D.
528-
for name, param in self.joint_search_space.parameters.items():
529-
if name not in existing_names and isinstance(param, RangeParameter):
530-
extra_names.append(name)
531-
extra_bounds.append((param.lower, param.upper))
532-
if not extra_names:
533-
return search_space_digest
534-
# Insert source-only params before the task feature
535-
task_features = search_space_digest.task_features
536-
if len(task_features) == 1:
537-
tf_idx = task_features[0]
538-
names = list(search_space_digest.feature_names)
539-
bounds = list(search_space_digest.bounds)
540-
# Raise if index-based fields (other than the task feature
541-
# itself) reference indices at or above tf_idx, since we would
542-
# need to shift them when inserting extra params.
543-
for field_name in (
544-
"ordinal_features",
545-
"categorical_features",
546-
"fidelity_features",
547-
):
548-
indices = getattr(search_space_digest, field_name)
549-
if any(i >= tf_idx for i in indices):
550-
raise UnsupportedError(
551-
f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
552-
)
553-
if any(
554-
i >= tf_idx and i not in task_features
555-
for i in search_space_digest.discrete_choices
556-
):
557-
raise UnsupportedError(
558-
f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
559-
)
560-
if search_space_digest.hierarchical_dependencies is not None and any(
561-
i >= tf_idx for i in search_space_digest.hierarchical_dependencies
562-
):
563-
raise UnsupportedError(
564-
"Cannot expand SSD: hierarchical_dependencies contains "
565-
f"index >= {tf_idx}."
566-
)
567-
names[tf_idx:tf_idx] = extra_names
568-
bounds[tf_idx:tf_idx] = extra_bounds
569-
n_extra = len(extra_names)
570-
new_task_features = [tf_idx + n_extra]
571-
new_target_values = dict(search_space_digest.target_values)
572-
if tf_idx in new_target_values:
573-
new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx)
574-
new_discrete = dict(search_space_digest.discrete_choices)
575-
if tf_idx in new_discrete:
576-
new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx)
577-
return dataclasses.replace(
578-
search_space_digest,
579-
feature_names=names,
580-
bounds=bounds,
581-
task_features=new_task_features,
582-
target_values=new_target_values,
583-
discrete_choices=new_discrete,
584-
)
585-
elif len(task_features) == 0:
586-
# No task feature -- just append.
587-
return dataclasses.replace(
588-
search_space_digest,
589-
feature_names=search_space_digest.feature_names + extra_names,
590-
bounds=search_space_digest.bounds + extra_bounds,
591-
)
592-
else:
593-
raise UnsupportedError(
594-
"Multiple task features are not supported in transfer learning."
595-
)
536+
return [p for p in all_params if p not in self._source_only_params]
596537

597538
def _fit(
598539
self,
@@ -610,15 +551,20 @@ def _fit(
610551
if experiment_data.arm_data.empty:
611552
# Temporarily unset self.outcomes to avoid an error in _get_fit_args.
612553
self.outcomes = []
554+
# Pre-compute the joint param ordering (mirrors _get_fit_args logic)
555+
# so we can derive the target-only subset for data extraction.
556+
all_params = list(search_space.parameters.keys())
557+
task_name = Keys.TASK_FEATURE_NAME.value
558+
if task_name in all_params:
559+
all_params.remove(task_name)
560+
all_params.append(task_name)
561+
target_data_params = self._get_target_data_parameters(all_params)
613562
datasets, candidate_metadata, search_space_digest = self._get_fit_args(
614563
search_space=search_space,
615564
experiment_data=experiment_data,
616565
update_outcomes_and_parameters=True,
566+
data_parameters=target_data_params,
617567
)
618-
# Expand SSD bounds to cover source-only params from the joint search
619-
# space. This ensures Normalize (and other input transforms) get bounds
620-
# for the full feature space, not just the target dims.
621-
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
622568
if experiment_data.arm_data.empty:
623569
self.outcomes = outcomes
624570
# Temporarily set datasets to None. We will construct empty datasets
@@ -656,12 +602,13 @@ def _cross_validate(
656602
) -> list[ObservationData]:
657603
if self.parameters is None:
658604
raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate"))
605+
target_data_params = self._get_target_data_parameters(self.parameters)
659606
datasets, _, search_space_digest = self._get_fit_args(
660607
search_space=search_space,
661608
experiment_data=cv_training_data,
662609
update_outcomes_and_parameters=False,
610+
data_parameters=target_data_params,
663611
)
664-
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
665612
# Add the task feature to SSD, to ensure that a multi-task model is selected.
666613
if len(search_space_digest.task_features) > 1:
667614
raise UnsupportedError(
@@ -728,9 +675,10 @@ def gen(
728675
if fixed_features is None:
729676
fixed_features = ObservationFeatures(parameters={})
730677
fixed_features.parameters.setdefault(name, target_p.value)
731-
# Fix source-only params so the optimizer doesn't search over them.
732-
# Center is a reasonable default; LearnedFeatureImputation overwrites
733-
# these with learned values when configured.
678+
# Fix source-only params that ARE in the search space (e.g. injected
679+
# as FixedParam with a backfill value) so the optimizer doesn't search
680+
# over them. Params NOT in the search space are handled by the model
681+
# internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
734682
joint_center = self.joint_search_space.compute_naive_center()
735683
for name, param in self.joint_search_space.parameters.items():
736684
if name not in search_space.parameters and isinstance(
@@ -739,6 +687,11 @@ def gen(
739687
if fixed_features is None:
740688
fixed_features = ObservationFeatures(parameters={})
741689
fixed_features.parameters.setdefault(name, joint_center[name])
690+
# At gen time, restrict self.parameters to params that exist in the
691+
# gen-time search space. Source-only params absent from _search_space
692+
# are handled by the model (LFI imputation or HeterogeneousMTGP).
693+
saved_parameters = self.parameters
694+
self.parameters = self._get_target_data_parameters(self.parameters)
742695
generator_run = super().gen(
743696
n=n,
744697
search_space=search_space,
@@ -747,6 +700,7 @@ def gen(
747700
fixed_features=fixed_features,
748701
model_gen_options=model_gen_options,
749702
)
703+
self.parameters = saved_parameters
750704
# Remove the parameters that are not in the target experiment's search
751705
# space, and update candidate_metadata_by_arm_signature to reflect the
752706
# new arm. We use the experiment's search space rather than

0 commit comments

Comments
 (0)