Skip to content

Commit 12ebbd9

Browse files
Carl Hvarfner and meta-codesync[bot]
authored and committed
Use search space bounds for Normalize in transfer learning adapter (#5184)
Summary: Pull Request resolved: #5184

The transfer learning adapter explicitly passed `bounds=None` to Normalize, forcing `learn_bounds=True`. This caused Normalize bounds to be learned from data instead of fixed to the search space, resulting in bounds that drift during training and differ between benchmark configs despite identical search spaces. Remove the `bounds=None` override so that `_set_default_bounds` provides the correct search space bounds from the SearchSpaceDigest.

Reviewed By: sdaulton
Differential Revision: D100669010
fbshipit-source-id: 68615da0ae10a369cddf8a441b4a5b0873594ab2
1 parent 9037c4b commit 12ebbd9

3 files changed

Lines changed: 236 additions & 13 deletions

File tree

ax/adapter/transfer_learning/adapter.py

Lines changed: 110 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import dataclasses
1011
import warnings
1112
from collections.abc import Mapping, Sequence
1213
from logging import Logger
@@ -38,7 +39,7 @@
3839
from ax.core.observation import ObservationData, ObservationFeatures
3940
from ax.core.optimization_config import OptimizationConfig
4041
from ax.core.parameter import FixedParameter, RangeParameter
41-
from ax.core.search_space import SearchSpace
42+
from ax.core.search_space import SearchSpace, SearchSpaceDigest
4243
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
4344
from ax.generation_strategy.best_model_selector import (
4445
ReductionCriterion,
@@ -504,6 +505,95 @@ def _get_task_datasets(
504505
)
505506
return task_datasets
506507

508+
def _expand_ssd_to_joint_space(
    self,
    search_space_digest: SearchSpaceDigest,
) -> SearchSpaceDigest:
    """Expand SSD bounds and feature_names to cover the joint search space.

    The SSD produced by ``_get_fit_args`` reflects the target search space.
    When source experiments have additional parameters, the model operates
    in the full joint feature space. This method adds bounds and feature
    names for source-only ``RangeParameter``s so that input transforms
    receive correct full-space bounds.

    Parameters that appear in both target and source keep the target
    bounds, so source observations outside those bounds will normalize
    outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
    for a *target* task in [0, 1]^D.

    Args:
        search_space_digest: The digest to expand. It is not mutated.

    Returns:
        ``search_space_digest`` itself when the joint search space holds no
        range parameters beyond those already in the digest. Otherwise a
        copy in which the extra features are inserted directly before the
        single task feature (or appended when there is no task feature),
        with ``task_features``, ``target_values`` and ``discrete_choices``
        re-keyed to the shifted task feature index.

    Raises:
        UnsupportedError: If there is more than one task feature, or if an
            index-based field (other than the task feature's own entries)
            references an index at or above the task feature index, since
            such indices would have to be shifted by the insertion.
    """
    existing_names = set(search_space_digest.feature_names)
    extra_names: list[str] = []
    extra_bounds: list[tuple[int | float, int | float]] = []
    # Only range parameters absent from the target SSD are collected;
    # shared parameters deliberately keep their target bounds.
    for name, param in self.joint_search_space.parameters.items():
        if name not in existing_names and isinstance(param, RangeParameter):
            extra_names.append(name)
            extra_bounds.append((param.lower, param.upper))
    if not extra_names:
        return search_space_digest

    task_features = search_space_digest.task_features
    if len(task_features) > 1:
        raise UnsupportedError(
            "Multiple task features are not supported in transfer learning."
        )
    if len(task_features) == 0:
        # No task feature -- nothing needs re-indexing, just append.
        return dataclasses.replace(
            search_space_digest,
            feature_names=list(search_space_digest.feature_names) + extra_names,
            bounds=list(search_space_digest.bounds) + extra_bounds,
        )

    # Exactly one task feature: the source-only params go right before it.
    tf_idx = task_features[0]

    # Inserting at tf_idx shifts every index >= tf_idx. Only the task
    # feature's own entries are re-keyed below, so any other reference to
    # such an index would silently go stale -- refuse instead.
    for field_name in (
        "ordinal_features",
        "categorical_features",
        "fidelity_features",
    ):
        indices = getattr(search_space_digest, field_name)
        if any(i >= tf_idx for i in indices):
            raise UnsupportedError(
                f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
            )
    if any(
        i >= tf_idx and i not in task_features
        for i in search_space_digest.discrete_choices
    ):
        raise UnsupportedError(
            f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
        )
    hierarchical = search_space_digest.hierarchical_dependencies
    if hierarchical is not None:
        # Both the parent indices (keys) and the child indices nested in
        # the per-value lists are feature positions, so check them all.
        referenced = set(hierarchical)
        for children_by_value in hierarchical.values():
            for children in children_by_value.values():
                referenced.update(children)
        if any(i >= tf_idx for i in referenced):
            raise UnsupportedError(
                "Cannot expand SSD: hierarchical_dependencies contains "
                f"index >= {tf_idx}."
            )

    names = list(search_space_digest.feature_names)
    bounds = list(search_space_digest.bounds)
    names[tf_idx:tf_idx] = extra_names
    bounds[tf_idx:tf_idx] = extra_bounds

    # The task feature moved right by the number of inserted params; move
    # its entries in the index-keyed dicts along with it.
    new_tf_idx = tf_idx + len(extra_names)
    new_target_values = dict(search_space_digest.target_values)
    if tf_idx in new_target_values:
        new_target_values[new_tf_idx] = new_target_values.pop(tf_idx)
    new_discrete = dict(search_space_digest.discrete_choices)
    if tf_idx in new_discrete:
        new_discrete[new_tf_idx] = new_discrete.pop(tf_idx)
    return dataclasses.replace(
        search_space_digest,
        feature_names=names,
        bounds=bounds,
        task_features=[new_tf_idx],
        target_values=new_target_values,
        discrete_choices=new_discrete,
    )
596+
507597
def _fit(
508598
self,
509599
search_space: SearchSpace,
@@ -525,6 +615,10 @@ def _fit(
525615
experiment_data=experiment_data,
526616
update_outcomes_and_parameters=True,
527617
)
618+
# Expand SSD bounds to cover source-only params from the joint search
619+
# space. This ensures Normalize (and other input transforms) get bounds
620+
# for the full feature space, not just the target dims.
621+
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
528622
if experiment_data.arm_data.empty:
529623
self.outcomes = outcomes
530624
# Temporarily set datasets to None. We will construct empty datasets
@@ -567,6 +661,7 @@ def _cross_validate(
567661
experiment_data=cv_training_data,
568662
update_outcomes_and_parameters=False,
569663
)
664+
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
570665
# Add the task feature to SSD, to ensure that a multi-task model is selected.
571666
if len(search_space_digest.task_features) > 1:
572667
raise UnsupportedError(
@@ -612,7 +707,7 @@ def gen(
612707
613708
Once the ``GeneratorRun`` is produced, it checks for any fixed parameters
614709
that are not in the target search space and removes them. This is a hack
615-
around limitations of the ``RemoveFixed`` transform. Since we construct the
710+
around limitations of the Ax ``RemoveFixed`` transform. Since we construct the
616711
transforms with the joint space, we end up adding back all fixed parameters
617712
from the joint space rather than adding only the parameters from the
618713
target search space. A proper fix would require passing in the search space
@@ -633,6 +728,17 @@ def gen(
633728
if fixed_features is None:
634729
fixed_features = ObservationFeatures(parameters={})
635730
fixed_features.parameters.setdefault(name, target_p.value)
731+
# Fix source-only params so the optimizer doesn't search over them.
732+
# Center is a reasonable default; LearnedFeatureImputation overwrites
733+
# these with learned values when configured.
734+
joint_center = self.joint_search_space.compute_naive_center()
735+
for name, param in self.joint_search_space.parameters.items():
736+
if name not in search_space.parameters and isinstance(
737+
param, RangeParameter
738+
):
739+
if fixed_features is None:
740+
fixed_features = ObservationFeatures(parameters={})
741+
fixed_features.parameters.setdefault(name, joint_center[name])
636742
generator_run = super().gen(
637743
n=n,
638744
search_space=search_space,
@@ -719,12 +825,8 @@ def transfer_learning_generator_specs_constructor(
719825
selector in case there is model selection enabled.
720826
"""
721827
input_transform_classes: list[type[InputTransform]] = [Normalize]
722-
input_transform_options = {
723-
"Normalize": {
724-
# None for bounds here ensures we do not use bounds from
725-
# the search space digest.
726-
"bounds": None,
727-
}
828+
input_transform_options: dict[str, dict[str, Any]] = {
829+
"Normalize": {},
728830
}
729831
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
730832
transform_configs = get_derelativize_config(
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from unittest.mock import MagicMock, PropertyMock
9+
10+
from ax.adapter.transfer_learning.adapter import TransferLearningAdapter
11+
from ax.core.parameter import ParameterType, RangeParameter
12+
from ax.core.search_space import SearchSpace, SearchSpaceDigest
13+
from ax.exceptions.core import UnsupportedError
14+
from ax.utils.common.testutils import TestCase
15+
16+
17+
class ExpandSsdToJointSpaceTest(TestCase):
    """Tests for ``TransferLearningAdapter._expand_ssd_to_joint_space``."""

    def setUp(self) -> None:
        super().setUp()
        # A spec'd mock is passed as ``self`` to the unbound method, so no
        # real adapter has to be constructed.
        self.adapter = MagicMock(spec=TransferLearningAdapter)

    def _make_joint_ss(self, params: dict[str, tuple[float, float]]) -> SearchSpace:
        """Build a search space of float range parameters from name -> bounds."""
        return SearchSpace(
            parameters=[
                RangeParameter(
                    name=param_name,
                    lower=lower,
                    upper=upper,
                    parameter_type=ParameterType.FLOAT,
                )
                for param_name, (lower, upper) in params.items()
            ]
        )

    def _use_joint_ss(self, params: dict[str, tuple[float, float]]) -> None:
        """Make the mocked adapter's ``joint_search_space`` return ``params``."""
        type(self.adapter).joint_search_space = PropertyMock(
            return_value=self._make_joint_ss(params)
        )

    def _expand(self, ssd: SearchSpaceDigest) -> SearchSpaceDigest:
        """Call the method under test with the mocked adapter as ``self``."""
        return TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd)

    def test_no_extra_params_returns_unchanged(self) -> None:
        self._use_joint_ss({"x1": (0, 1), "x2": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
        )
        # With nothing to add, the very same object must come back.
        self.assertIs(self._expand(ssd), ssd)

    def test_single_task_feature_inserts_before_task(self) -> None:
        self._use_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (-2, 5)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
        )
        expanded = self._expand(ssd)
        self.assertEqual(expanded.feature_names, ["x1", "x2", "x3", "task"])
        self.assertEqual(expanded.bounds, [(0, 1), (0, 1), (-2, 5), (0, 2)])
        self.assertEqual(expanded.task_features, [3])
        self.assertEqual(expanded.target_values, {3: 0})

    def test_zero_task_features_appends(self) -> None:
        self._use_joint_ss({"x1": (0, 1), "x2": (-1, 3)})
        ssd = SearchSpaceDigest(
            feature_names=["x1"],
            bounds=[(0, 1)],
        )
        expanded = self._expand(ssd)
        self.assertEqual(expanded.feature_names, ["x1", "x2"])
        self.assertEqual(expanded.bounds, [(0, 1), (-1, 3)])

    def test_discrete_choices_on_task_feature_shifted(self) -> None:
        self._use_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
            discrete_choices={2: [0, 1, 2]},
        )
        expanded = self._expand(ssd)
        self.assertEqual(expanded.discrete_choices, {3: [0, 1, 2]})
        self.assertEqual(expanded.task_features, [3])

    def test_hierarchical_dependencies_at_task_idx_raises(self) -> None:
        self._use_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
            hierarchical_dependencies={2: {0: [1]}},
        )
        with self.assertRaisesRegex(UnsupportedError, "hierarchical_dependencies"):
            self._expand(ssd)

    def test_multiple_task_features_raises(self) -> None:
        self._use_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "task1", "task2"],
            bounds=[(0, 1), (0, 1), (0, 1)],
            task_features=[1, 2],
        )
        with self.assertRaisesRegex(UnsupportedError, "Multiple task features"):
            self._expand(ssd)

ax/generators/torch/botorch_modular/surrogate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -738,11 +738,15 @@ def fit(
738738
candidate_metadata=candidate_metadata,
739739
)
740740

741-
# Only update the outcome names and models if the dataset input matches
742-
# the feature names from the search space digest. Otherwise we only
743-
# keep the model within self._submodels as it may be models fitted on
744-
# auxiliary data such as the preference model for BOPE
745-
if set(dataset.feature_names) == feature_names_set:
741+
# Only update the outcome names and models if the dataset input
742+
# matches the feature names from the SSD. In heterogeneous TL,
743+
# _expand_ssd_to_joint_space adds source-only features to the SSD,
744+
# so the target MultiTaskDataset's feature_names will be a strict
745+
# subset -- the missing names are source-only params.
746+
if set(dataset.feature_names) == feature_names_set or (
747+
isinstance(dataset, MultiTaskDataset)
748+
and set(dataset.feature_names).issubset(feature_names_set)
749+
):
746750
models.append(model)
747751
outcome_names.extend(dataset.outcome_names)
748752

0 commit comments

Comments
 (0)