Skip to content

Commit 8046f90

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Use StratifiedStandardize for per-task Y standardization in TL
Summary: Adds per-task outcome standardization to the transfer learning adapter, ensuring each task's observations are standardized independently rather than jointly. Updates the default transform pipeline to use TL-specific outcome transforms. Differential Revision: D102197139
1 parent 8971a90 commit 8046f90

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

ax/adapter/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
]
140140

141141
Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
142+
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]
142143

143144
# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
144145
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.

ax/adapter/transfer_learning/adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
GeneratorSetup,
2525
MBM_X_trans,
2626
Y_trans,
27+
Y_trans,
2728
)
2829
from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter
2930
from ax.adapter.transfer_learning.utils import get_joint_search_space
@@ -54,6 +55,7 @@
5455
from ax.utils.common.logger import get_logger
5556
from botorch.models.multitask import MultiTaskGP
5657
from botorch.models.transforms.input import InputTransform, Normalize
58+
from botorch.models.transforms.outcome import StratifiedStandardize
5759
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
5860
from gpytorch.kernels.kernel import Kernel
5961
from pyre_extensions import assert_is_instance
@@ -846,6 +848,7 @@ def transfer_learning_generator_specs_constructor(
846848
botorch_model_class=model_class,
847849
model_options=botorch_model_kwargs or {},
848850
input_transform_classes=input_transform_classes,
851+
outcome_transform_classes=[StratifiedStandardize],
849852
input_transform_options=input_transform_options,
850853
mll_options=mll_kwargs,
851854
covar_module_class=covar_module_class,

0 commit comments

Comments
 (0)