1 change: 1 addition & 0 deletions ax/adapter/registry.py
@@ -139,6 +139,7 @@
]

Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]

# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
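The only difference between the two stacks is `StandardizeY`: for transfer learning, outcome standardization moves out of the Ax transform layer and into the BoTorch model (via `StratifiedStandardize`, added to the generator spec later in this diff), so a pooled `StandardizeY` would be redundant. A quick check of the relationship, assuming this branch of Ax is importable:

```python
from ax.adapter.registry import TL_Y_trans, Y_trans

# TL_Y_trans is Y_trans minus the pooled outcome standardization.
removed = [t.__name__ for t in Y_trans if t not in TL_Y_trans]
assert removed == ["StandardizeY"]
```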
21 changes: 19 additions & 2 deletions ax/adapter/torch.py
@@ -774,6 +774,7 @@ def _get_fit_args(
search_space: SearchSpace,
experiment_data: ExperimentData,
update_outcomes_and_parameters: bool,
data_parameters: list[str] | None = None,
) -> tuple[
list[SupervisedDataset],
list[list[TCandidateMetadata]] | None,
@@ -791,6 +792,11 @@ def _get_fit_args(
update_outcomes_and_parameters: Whether to update `self.outcomes` with
all outcomes found in the observations and `self.parameters` with
all parameters in the search space. Typically only used in `_fit`.
data_parameters: When provided, columns to extract from
``experiment_data``. Defaults to ``self.parameters``. Useful when
the model space is larger than the data (e.g. transfer learning
with heterogeneous search spaces where the data only contains
target columns but the model operates in a joint feature space).

Returns:
The datasets & metadata, extracted from the ``experiment_data``, and the
@@ -818,12 +824,23 @@
search_space_digest = extract_search_space_digest(
search_space=search_space, param_names=self.parameters
)
extract_params = (
data_parameters if data_parameters is not None else self.parameters
)
# When data_parameters differs from self.parameters, the SSD's
# task_feature indices don't match the data columns. Pass None
# to skip MultiTaskDataset wrapping (the caller handles it).
extraction_ssd = (
None
if data_parameters is not None and data_parameters != self.parameters
else search_space_digest
)
# Convert observations to datasets
datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data(
experiment_data=experiment_data,
outcomes=self.outcomes,
parameters=self.parameters,
search_space_digest=search_space_digest,
parameters=extract_params,
search_space_digest=extraction_ssd,
)
datasets = self._update_w_aux_exp_datasets(datasets=datasets)

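A standalone sketch of the guard added above, with plain-Python stand-ins for the adapter state (names are illustrative, not part of the diff): when `data_parameters` differs from `self.parameters`, the digest is withheld so the converter skips `MultiTaskDataset` wrapping.

```python
def choose_extraction_ssd(data_parameters, self_parameters, ssd):
    # The SSD's task_feature indices are positions into self.parameters;
    # they would be wrong for a narrower data_parameters column set, so
    # return None and let the caller handle MultiTaskDataset wrapping.
    if data_parameters is not None and data_parameters != self_parameters:
        return None
    return ssd

assert choose_extraction_ssd(None, ["x1", "task"], "ssd") == "ssd"
assert choose_extraction_ssd(["x1", "task"], ["x0", "x1", "task"], "ssd") is None
```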
191 changes: 72 additions & 119 deletions ax/adapter/transfer_learning/adapter.py
@@ -7,7 +7,6 @@

from __future__ import annotations

import dataclasses
import warnings
from collections.abc import Mapping, Sequence
from logging import Logger
@@ -23,7 +22,7 @@
Generators,
GeneratorSetup,
MBM_X_trans,
Y_trans,
TL_Y_trans,
)
from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter
from ax.adapter.transfer_learning.utils import get_joint_search_space
@@ -39,7 +38,7 @@
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import FixedParameter, RangeParameter
from ax.core.search_space import SearchSpace, SearchSpaceDigest
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.generation_strategy.best_model_selector import (
ReductionCriterion,
@@ -54,6 +53,7 @@
from ax.utils.common.logger import get_logger
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import StratifiedStandardize
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from gpytorch.kernels.kernel import Kernel
from pyre_extensions import assert_is_instance
@@ -93,6 +93,7 @@ def __init__(
``experiment.auxiliary_experiments_by_purpose`` with type set to
``AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT``.
"""
self._source_only_params: set[str] = set()
transforms = [] if transforms is None else list(transforms)
if MetadataToTask not in transforms:
raise UserInputError(
@@ -150,22 +151,7 @@
target_search_space=search_space,
)

# Add source-only backfilled params as FixedParameter to the target
# search space so that the compatibility check passes and the model
# space includes these params (FixedToTunable will later convert them
# to RangeParameter using the joint space bounds).
search_space = search_space.clone() # avoid mutating caller's object
target_param_names = set(search_space.parameters.keys())
for name, param in self.joint_search_space.parameters.items():
if name not in target_param_names and param.backfill_value is not None:
search_space.add_parameter(
FixedParameter(
name=name,
parameter_type=param.parameter_type,
value=param.backfill_value,
)
)

# Include backfill param names in filled_params so Phase 1 of
# check_search_space_compatibility passes for target-only params.
filled_params.extend(self.joint_search_space.backfill_values().keys())
@@ -196,6 +182,28 @@ def __init__(
default_model_gen_options=default_model_gen_options,
)

def _set_search_space(self, search_space: SearchSpace) -> None:
"""Set search space and model space for transfer learning.

Overrides the base class to add source-only params (as RangeParameters)
to ``_model_space`` while preserving target bounds for shared params.
This ensures the SSD naturally covers the full joint feature space
without needing post-hoc expansion, and Normalize is anchored to target
bounds so target data maps to [0, 1].
"""
self._search_space = search_space.clone()
model_space = search_space.clone()
self._source_only_params = set()
for name, param in self.joint_search_space.parameters.items():
if name not in model_space.parameters and isinstance(param, RangeParameter):
model_space.add_parameter(param.clone())
# Only mark as source-only if no backfill value exists.
# Backfilled params have known values (via FillMissingParameters)
# and should be included in data extraction.
if param.backfill_value is None:
self._source_only_params.add(name)
self._model_space = model_space

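A self-contained sketch of the override's bookkeeping, using bounds dicts in place of Ax `SearchSpace` objects (all names illustrative): shared parameters keep target bounds because only names missing from the target space are added from the joint space.

```python
def build_model_space(target_params, joint_params, backfill_values):
    model_space = dict(target_params)  # target bounds win for shared params
    source_only = set()
    for name, bounds in joint_params.items():
        if name not in model_space:
            model_space[name] = bounds
            # Backfilled params have known values in the data, so only
            # params without a backfill value are marked source-only.
            if backfill_values.get(name) is None:
                source_only.add(name)
    return model_space, source_only

space, source_only = build_model_space(
    target_params={"x1": (0.0, 1.0)},
    joint_params={"x1": (0.0, 2.0), "x2": (-1.0, 1.0), "x3": (0.0, 5.0)},
    backfill_values={"x3": 2.5},
)
assert space["x1"] == (0.0, 1.0)  # target bounds preserved for shared x1
assert source_only == {"x2"}  # x3 has a backfill value, so it is excluded
```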
def _transform_data(
self,
experiment_data: ExperimentData,
@@ -505,94 +513,18 @@ def _get_task_datasets(
)
return task_datasets

def _expand_ssd_to_joint_space(
def _get_target_data_parameters(
self,
search_space_digest: SearchSpaceDigest,
) -> SearchSpaceDigest:
"""Expand SSD bounds and feature_names to cover the joint search space.

The SSD produced by ``_get_fit_args`` reflects the target search space.
When source experiments have additional parameters, the model operates
in the full joint feature space. This method appends bounds and feature
names for source-only parameters so that input transforms receive
correct full-space bounds.
all_params: list[str],
) -> list[str]:
"""Filter a joint parameter list to target-only params + task feature.

Source-only params (those added by ``_set_search_space`` from the joint
space) are excluded because the target experiment data does not have
those columns. Uses untransformed names, which are stable across
transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
"""
existing_names = set(search_space_digest.feature_names)
extra_names: list[str] = []
extra_bounds: list[tuple[int | float, int | float]] = []
# Only collect parameters absent from the target SSD. Shared
# parameters that appear in both target and source keep the target
# bounds -- source observations outside those bounds will normalize
# outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
# for a __target__ task in [0, 1]^D.
for name, param in self.joint_search_space.parameters.items():
if name not in existing_names and isinstance(param, RangeParameter):
extra_names.append(name)
extra_bounds.append((param.lower, param.upper))
if not extra_names:
return search_space_digest
# Insert source-only params before the task feature
task_features = search_space_digest.task_features
if len(task_features) == 1:
tf_idx = task_features[0]
names = list(search_space_digest.feature_names)
bounds = list(search_space_digest.bounds)
# Raise if index-based fields (other than the task feature
# itself) reference indices at or above tf_idx, since we would
# need to shift them when inserting extra params.
for field_name in (
"ordinal_features",
"categorical_features",
"fidelity_features",
):
indices = getattr(search_space_digest, field_name)
if any(i >= tf_idx for i in indices):
raise UnsupportedError(
f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
)
if any(
i >= tf_idx and i not in task_features
for i in search_space_digest.discrete_choices
):
raise UnsupportedError(
f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
)
if search_space_digest.hierarchical_dependencies is not None and any(
i >= tf_idx for i in search_space_digest.hierarchical_dependencies
):
raise UnsupportedError(
"Cannot expand SSD: hierarchical_dependencies contains "
f"index >= {tf_idx}."
)
names[tf_idx:tf_idx] = extra_names
bounds[tf_idx:tf_idx] = extra_bounds
n_extra = len(extra_names)
new_task_features = [tf_idx + n_extra]
new_target_values = dict(search_space_digest.target_values)
if tf_idx in new_target_values:
new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx)
new_discrete = dict(search_space_digest.discrete_choices)
if tf_idx in new_discrete:
new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx)
return dataclasses.replace(
search_space_digest,
feature_names=names,
bounds=bounds,
task_features=new_task_features,
target_values=new_target_values,
discrete_choices=new_discrete,
)
elif len(task_features) == 0:
# No task feature -- just append.
return dataclasses.replace(
search_space_digest,
feature_names=search_space_digest.feature_names + extra_names,
bounds=search_space_digest.bounds + extra_bounds,
)
else:
raise UnsupportedError(
"Multiple task features are not supported in transfer learning."
)
return [p for p in all_params if p not in self._source_only_params]

def _fit(
self,
@@ -610,15 +542,20 @@
if experiment_data.arm_data.empty:
# Temporarily unset self.outcomes to avoid an error in _get_fit_args.
self.outcomes = []
# Pre-compute the joint param ordering (mirrors _get_fit_args logic)
# so we can derive the target-only subset for data extraction.
all_params = list(search_space.parameters.keys())
task_name = Keys.TASK_FEATURE_NAME.value
if task_name in all_params:
all_params.remove(task_name)
all_params.append(task_name)
target_data_params = self._get_target_data_parameters(all_params)
datasets, candidate_metadata, search_space_digest = self._get_fit_args(
search_space=search_space,
experiment_data=experiment_data,
update_outcomes_and_parameters=True,
data_parameters=target_data_params,
)
# Expand SSD bounds to cover source-only params from the joint search
# space. This ensures Normalize (and other input transforms) get bounds
# for the full feature space, not just the target dims.
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
if experiment_data.arm_data.empty:
self.outcomes = outcomes
# Temporarily set datasets to None. We will construct empty datasets
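A toy run of the two ordering steps above (illustrative names): the task-feature column is moved to the end of the joint parameter list, then source-only columns are dropped for data extraction.

```python
task_name = "task"  # stands in for Keys.TASK_FEATURE_NAME.value
all_params = ["x1", "task", "x2", "x3"]
if task_name in all_params:
    all_params.remove(task_name)
    all_params.append(task_name)  # task feature always goes last
assert all_params == ["x1", "x2", "x3", "task"]

source_only = {"x2"}  # stands in for self._source_only_params
target_data_params = [p for p in all_params if p not in source_only]
assert target_data_params == ["x1", "x3", "task"]
```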
@@ -656,12 +593,13 @@ def _cross_validate(
) -> list[ObservationData]:
if self.parameters is None:
raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate"))
target_data_params = self._get_target_data_parameters(self.parameters)
datasets, _, search_space_digest = self._get_fit_args(
search_space=search_space,
experiment_data=cv_training_data,
update_outcomes_and_parameters=False,
data_parameters=target_data_params,
)
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
# Add the task feature to SSD, to ensure that a multi-task model is selected.
if len(search_space_digest.task_features) > 1:
raise UnsupportedError(
@@ -714,23 +652,22 @@ def gen(
to ``RemoveFixed.untransform_observation_features``, which requires updating
the signature of all transforms.
"""
# If a fixed parameter in the target search space, is
# a range parameter in the joint search space, then we
# should set it as a fixed feature here.
# If a fixed parameter in the target search space is a range
# parameter in the joint search space, pin it as a fixed feature.
search_space = search_space or self._search_space
for name, target_p in search_space.parameters.items():
if (
isinstance(target_p, FixedParameter)
and (p := self.joint_search_space.parameters.get(name)) is not None
and isinstance(p, RangeParameter)
):
# add to fixed features
if fixed_features is None:
fixed_features = ObservationFeatures(parameters={})
fixed_features.parameters.setdefault(name, target_p.value)
# Fix source-only params so the optimizer doesn't search over them.
# Center is a reasonable default; LearnedFeatureImputation overwrites
# these with learned values when configured.
# Fix source-only params that ARE in the search space (e.g. injected
# as FixedParam with a backfill value) so the optimizer doesn't search
# over them. Params NOT in the search space are handled by the model
# internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
joint_center = self.joint_search_space.compute_naive_center()
for name, param in self.joint_search_space.parameters.items():
if name not in search_space.parameters and isinstance(
@@ -739,6 +676,20 @@
if fixed_features is None:
fixed_features = ObservationFeatures(parameters={})
fixed_features.parameters.setdefault(name, joint_center[name])
# At gen time, optimize over the target search space only.
# _source_only_params covers params without backfill (stable
# untransformed names). Backfilled source-only params are identified
# by checking joint SS backfill_values — their names are also stable
# (RangeParameters are never renamed by transforms).
saved_parameters = self.parameters
backfilled_source_only = {
name
for name, param in self.joint_search_space.parameters.items()
if name not in self._experiment.search_space.parameters
and param.backfill_value is not None
}
exclude = self._source_only_params | backfilled_source_only
self.parameters = [p for p in self.parameters if p not in exclude]
generator_run = super().gen(
n=n,
search_space=search_space,
@@ -747,6 +698,7 @@
fixed_features=fixed_features,
model_gen_options=model_gen_options,
)
self.parameters = saved_parameters
# Remove the parameters that are not in the target experiment's search
# space, and update candidate_metadata_by_arm_signature to reflect the
# new arm. We use the experiment's search space rather than
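A toy version of the gen-time masking (illustrative values): parameters the optimizer must not search over are dropped from `self.parameters` for the duration of the `super().gen()` call and restored afterwards.

```python
joint = {"x1": None, "x2": None, "x3": 2.5}  # name -> backfill_value
target = {"x1"}                              # target search space names
source_only = {"x2"}                         # self._source_only_params
backfilled_source_only = {
    n for n, bf in joint.items() if n not in target and bf is not None
}
exclude = source_only | backfilled_source_only
parameters = ["x1", "x2", "x3", "task"]
assert [p for p in parameters if p not in exclude] == ["x1", "task"]
```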
@@ -793,7 +745,7 @@ def transfer_learning_generator_specs_constructor(
Args:
model_class: The MultiTask BoTorch Model to use in the BOTL.
transform: Optional list of transforms to use in the Adapter.
Defaults to MBM_X_trans + [MetadataToTask] + Y_trans.
Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans.
jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian
model is used.
torch_device: What torch device to use (defaults to None, i.e. falls back to
@@ -828,7 +780,7 @@ def transfer_learning_generator_specs_constructor(
input_transform_options: dict[str, dict[str, Any]] = {
"Normalize": {},
}
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans
transform_configs = get_derelativize_config(
derelativize_with_raw_status_quo=derelativize_with_raw_status_quo
)
@@ -846,6 +798,7 @@
botorch_model_class=model_class,
model_options=botorch_model_kwargs or {},
input_transform_classes=input_transform_classes,
outcome_transform_classes=[StratifiedStandardize],
input_transform_options=input_transform_options,
mll_options=mll_kwargs,
covar_module_class=covar_module_class,
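Why `StratifiedStandardize` here rather than `StandardizeY` at the Ax layer: it standardizes each task's outcomes within its own group, so source and target tasks with very different outcome scales do not distort the fit. A torch-only sketch of that behavior (not the BoTorch implementation itself):

```python
import torch

Y = torch.tensor([1.0, 3.0, 100.0, 300.0])
task = torch.tensor([0, 0, 1, 1])  # task-feature column
Y_std = Y.clone()
for t in task.unique():
    mask = task == t
    Y_std[mask] = (Y[mask] - Y[mask].mean()) / Y[mask].std()
# Each task is centered and scaled separately; pooled standardization
# would instead be dominated by the larger-scale task.
```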
@@ -887,5 +840,5 @@
GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup(
adapter_class=TransferLearningAdapter,
generator_class=BoTorchGenerator,
transforms=MBM_X_trans + [MetadataToTask] + Y_trans,
transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans,
)