Skip to content

Commit 0ea6723

Browse files
Carl Hvarfner and facebook-github-bot
authored and committed
Wire LearnedFeatureImputation and map_heterogeneous_to_full for MultiTaskGP (#5192)
Summary: X-link: meta-pytorch/botorch#3296 Automatically configures learned feature imputation for models that pad heterogeneous per-task data to the full joint feature space. Models with native heterogeneity support are excluded from this automatic configuration. Differential Revision: D101841497
1 parent 12ebbd9 commit 0ea6723

4 files changed

Lines changed: 181 additions & 43 deletions

File tree

ax/generators/torch/botorch_modular/surrogate.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections import OrderedDict
1313
from collections.abc import Mapping, Sequence
1414
from copy import deepcopy
15-
from dataclasses import dataclass, field
15+
from dataclasses import dataclass, field, replace
1616
from logging import Logger
1717
from typing import Any, cast
1818

@@ -67,6 +67,7 @@
6767
from botorch.models.transforms.input import (
6868
ChainedInputTransform,
6969
InputTransform,
70+
LearnedFeatureImputation,
7071
Normalize,
7172
)
7273
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
@@ -1253,6 +1254,22 @@ def _submodel_input_constructor_mtgp(
12531254
) -> dict[str, Any]:
12541255
if len(dataset.outcome_names) > 1:
12551256
raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
1257+
# If LearnedFeatureImputation is in the model config, tell construct_inputs
1258+
# to map heterogeneous per-task features to the full joint feature space.
1259+
# This must happen before the base call so construct_inputs can handle
1260+
# heterogeneous MultiTaskDatasets without raising.
1261+
uses_lfi = isinstance(model_config.input_transform_classes, list) and any(
1262+
issubclass(cls, LearnedFeatureImputation)
1263+
for cls in model_config.input_transform_classes
1264+
)
1265+
if uses_lfi and "map_heterogeneous_to_full" not in model_config.model_options:
1266+
model_config = replace(
1267+
model_config,
1268+
model_options={
1269+
**model_config.model_options,
1270+
"map_heterogeneous_to_full": True,
1271+
},
1272+
)
12561273
formatted_model_inputs = _submodel_input_constructor_base(
12571274
botorch_model_class=botorch_model_class,
12581275
model_config=model_config,
@@ -1266,9 +1283,12 @@ def _submodel_input_constructor_mtgp(
12661283
# specify output tasks so that model.num_outputs = 1
12671284
# since the model only models a single outcome
12681285
if formatted_model_inputs.get("output_tasks") is None:
1269-
# SSD doesn't use -1, so we need to normalize here
1286+
# SSD doesn't use -1, so we need to normalize here. Use the SSD's bound
1287+
# length since target_values is keyed by SSD column index — for
1288+
# heterogeneous MultiTaskDatasets this differs from the per-task
1289+
# dataset's feature_names length.
12701290
task_feature = none_throws(
1271-
normalize_indices(indices=[task_feature], d=len(dataset.feature_names))
1291+
normalize_indices(indices=[task_feature], d=len(search_space_digest.bounds))
12721292
)[0]
12731293
if (search_space_digest.target_values is not None) and (
12741294
target_value := search_space_digest.target_values.get(task_feature)

ax/generators/torch/botorch_modular/utils.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,32 @@ def use_model_list(
221221
return True
222222

223223

224+
def _ensure_input_transform(
225+
model_config: ModelConfig,
226+
transform_cls: type[InputTransform],
227+
position: int | None = None,
228+
) -> None:
229+
"""Ensure ``transform_cls`` is in ``model_config.input_transform_classes``.
230+
231+
If the user hasn't specified any transforms (``DEFAULT``), initialise the
232+
list with ``[transform_cls]``. Otherwise append (or insert at ``position``)
233+
only when the class isn't already present. Mutates ``model_config``
234+
in-place.
235+
"""
236+
itc = model_config.input_transform_classes
237+
if isinstance(itc, list):
238+
if transform_cls not in itc:
239+
if position is not None:
240+
itc.insert(position, transform_cls)
241+
else:
242+
itc.append(transform_cls)
243+
else:
244+
model_config.input_transform_classes = [transform_cls]
245+
ito = model_config.input_transform_options or {}
246+
ito.setdefault(transform_cls.__name__, {})
247+
model_config.input_transform_options = ito
248+
249+
224250
def copy_model_config_with_default_values(
225251
model_config: ModelConfig,
226252
dataset: SupervisedDataset,
@@ -235,43 +261,15 @@ def copy_model_config_with_default_values(
235261
specified_model_class=model_config_copy.botorch_model_class,
236262
)
237263

238-
# Handle heterogeneous multi-task datasets.
264+
# Handle heterogeneous multi-task datasets: ensure Normalize is present
265+
# and add LearnedFeatureImputation for models that don't handle
266+
# heterogeneity natively.
239267
if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features:
240-
if model_config_copy.botorch_model_class is HeterogeneousMTGP:
241-
# HeterogeneousMTGP handles heterogeneity natively; just ensure
242-
# Normalize is present (bounds are set later by the TL adapter).
243-
itc = model_config_copy.input_transform_classes
244-
if isinstance(itc, list):
245-
if Normalize not in itc:
246-
itc.insert(0, Normalize)
247-
ito = model_config_copy.input_transform_options or {}
248-
ito.setdefault("Normalize", {"bounds": None})
249-
model_config_copy.input_transform_options = ito
250-
else:
251-
model_config_copy.input_transform_classes = [Normalize]
252-
ito = model_config_copy.input_transform_options or {}
253-
ito.setdefault("Normalize", {"bounds": None})
254-
model_config_copy.input_transform_options = ito
255-
else:
256-
# Other models need Normalize + LFI to pad features via
257-
# map_heterogeneous_to_full.
258-
itc = model_config_copy.input_transform_classes
259-
if isinstance(itc, list):
260-
if Normalize not in itc:
261-
itc.insert(0, Normalize)
262-
ito = model_config_copy.input_transform_options or {}
263-
ito.setdefault("Normalize", {"bounds": None})
264-
model_config_copy.input_transform_options = ito
265-
if LearnedFeatureImputation not in itc:
266-
itc.append(LearnedFeatureImputation)
267-
else:
268-
model_config_copy.input_transform_classes = [
269-
Normalize,
270-
LearnedFeatureImputation,
271-
]
272-
ito = model_config_copy.input_transform_options or {}
273-
ito.setdefault("Normalize", {"bounds": None})
274-
model_config_copy.input_transform_options = ito
268+
_ensure_input_transform(model_config_copy, Normalize, position=0)
269+
if model_config_copy.botorch_model_class is not None and not issubclass(
270+
model_config_copy.botorch_model_class, HeterogeneousMTGP
271+
):
272+
_ensure_input_transform(model_config_copy, LearnedFeatureImputation)
275273

276274
if model_config_copy.mll_class is None:
277275
model_config_copy.mll_class = (

ax/generators/torch/tests/test_surrogate.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
_construct_specified_input_transforms,
3333
_extract_model_kwargs,
3434
_make_botorch_input_transform,
35+
_submodel_input_constructor_mtgp,
3536
submodel_input_constructor,
3637
Surrogate,
3738
SurrogateSpec,
@@ -59,7 +60,12 @@
5960
from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks.
6061
from botorch.models.multitask import MultiTaskGP
6162
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
62-
from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize
63+
from botorch.models.transforms.input import (
64+
ChainedInputTransform,
65+
LearnedFeatureImputation,
66+
Log10,
67+
Normalize,
68+
)
6369
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
6470
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
6571
from botorch.utils.evaluation import compute_in_sample_model_fit_metric
@@ -277,6 +283,72 @@ def test__make_botorch_input_transform(self) -> None:
277283
self.assertEqual(transform.indices.tolist(), [0])
278284
self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]])
279285

286+
def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None:
287+
"""_submodel_input_constructor_mtgp passes map_heterogeneous_to_full
288+
to construct_inputs when LFI is configured, enabling zero-padded
289+
heterogeneous datasets to be used with MultiTaskGP."""
290+
ds_target = SupervisedDataset(
291+
X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]),
292+
Y=torch.tensor([[1.0], [2.0]]),
293+
feature_names=["x0", "task"],
294+
outcome_names=["y_task_0"],
295+
)
296+
ds_source = SupervisedDataset(
297+
X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]),
298+
Y=torch.tensor([[3.0], [4.0]]),
299+
feature_names=["x0", "x1", "task"],
300+
outcome_names=["y_task_1"],
301+
)
302+
mt_dataset = MultiTaskDataset(
303+
datasets=[ds_target, ds_source],
304+
target_outcome_name="y_task_0",
305+
task_feature_index=-1,
306+
)
307+
self.assertTrue(mt_dataset.has_heterogeneous_features)
308+
ssd = SearchSpaceDigest(
309+
feature_names=["x0", "x1", "task"],
310+
bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)],
311+
task_features=[2],
312+
target_values={2: 0.0},
313+
)
314+
surrogate = Surrogate(
315+
surrogate_spec=SurrogateSpec(
316+
model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)]
317+
)
318+
)
319+
320+
with self.subTest("with LFI — construct_inputs succeeds"):
321+
config_with_lfi = ModelConfig(
322+
botorch_model_class=MultiTaskGP,
323+
input_transform_classes=[Normalize, LearnedFeatureImputation],
324+
)
325+
result = _submodel_input_constructor_mtgp(
326+
botorch_model_class=MultiTaskGP,
327+
model_config=config_with_lfi,
328+
dataset=mt_dataset,
329+
search_space_digest=ssd,
330+
surrogate=surrogate,
331+
)
332+
self.assertEqual(result["train_X"].shape[-1], 3)
333+
334+
with self.subTest("without LFI — construct_inputs raises"):
335+
from botorch.exceptions.errors import (
336+
UnsupportedError as BotorchUnsupportedError,
337+
)
338+
339+
config_no_lfi = ModelConfig(
340+
botorch_model_class=MultiTaskGP,
341+
input_transform_classes=[Normalize],
342+
)
343+
with self.assertRaises(BotorchUnsupportedError):
344+
_submodel_input_constructor_mtgp(
345+
botorch_model_class=MultiTaskGP,
346+
model_config=config_no_lfi,
347+
dataset=mt_dataset,
348+
search_space_digest=ssd,
349+
surrogate=surrogate,
350+
)
351+
280352

281353
class SurrogateTest(TestCase):
282354
def setUp(self, cuda: bool = False) -> None:

ax/generators/torch/tests/test_utils.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
from botorch.models.model_list_gp_regression import ModelListGP
7171
from botorch.models.multitask import MultiTaskGP
7272
from botorch.models.pairwise_gp import PairwiseGP
73-
from botorch.models.transforms.input import LearnedFeatureImputation, Normalize
73+
from botorch.models.transforms.input import LearnedFeatureImputation, Normalize, Warp
7474
from botorch.posteriors.ensemble import EnsemblePosterior
7575
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
7676
from botorch.utils.types import DEFAULT
@@ -244,7 +244,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None:
244244
self.assertEqual(updated_config.input_transform_classes, [Normalize])
245245
self.assertEqual(
246246
none_throws(updated_config.input_transform_options),
247-
{"Normalize": {"bounds": None}},
247+
{"Normalize": {}},
248248
)
249249

250250
# Explicit HeterogeneousMTGP behaves the same.
@@ -257,7 +257,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None:
257257
self.assertEqual(updated_config.input_transform_classes, [Normalize])
258258
self.assertEqual(
259259
none_throws(updated_config.input_transform_options),
260-
{"Normalize": {"bounds": None}},
260+
{"Normalize": {}},
261261
)
262262

263263
def test_copy_model_config_mtgp_with_lfi_injection(self) -> None:
@@ -302,6 +302,54 @@ def test_copy_model_config_does_not_add_normalize_for_other_models(self) -> None
302302
self.assertEqual(updated_config.input_transform_classes, DEFAULT)
303303
self.assertEqual(updated_config.input_transform_options, {})
304304

305+
def test_copy_model_config_adds_imputation_for_heterogeneous(self) -> None:
306+
mt_dataset = self._get_heterogeneous_mt_dataset()
307+
ssd = dataclasses.replace(self.search_space_digest, task_features=[-1])
308+
309+
with self.subTest("no_input_transform_classes"):
310+
model_config = ModelConfig(botorch_model_class=MultiTaskGP)
311+
updated_config = copy_model_config_with_default_values(
312+
model_config=model_config,
313+
dataset=mt_dataset,
314+
search_space_digest=ssd,
315+
)
316+
self.assertEqual(updated_config.botorch_model_class, MultiTaskGP)
317+
self.assertEqual(
318+
updated_config.input_transform_classes,
319+
[Normalize, LearnedFeatureImputation],
320+
)
321+
322+
with self.subTest("existing_transform_classes"):
323+
model_config = ModelConfig(
324+
botorch_model_class=MultiTaskGP,
325+
input_transform_classes=[Warp],
326+
input_transform_options={"Warp": {}},
327+
)
328+
updated_config = copy_model_config_with_default_values(
329+
model_config=model_config,
330+
dataset=mt_dataset,
331+
search_space_digest=ssd,
332+
)
333+
self.assertEqual(
334+
updated_config.input_transform_classes,
335+
[Normalize, Warp, LearnedFeatureImputation],
336+
)
337+
338+
with self.subTest("imputation_already_present"):
339+
model_config = ModelConfig(
340+
botorch_model_class=MultiTaskGP,
341+
input_transform_classes=[Normalize, LearnedFeatureImputation],
342+
)
343+
updated_config = copy_model_config_with_default_values(
344+
model_config=model_config,
345+
dataset=mt_dataset,
346+
search_space_digest=ssd,
347+
)
348+
self.assertEqual(
349+
updated_config.input_transform_classes,
350+
[Normalize, LearnedFeatureImputation],
351+
)
352+
305353
def test_choose_model_class_discrete_features(self) -> None:
306354
# With discrete features, use MixedSingleTaskyGP.
307355
self.assertEqual(

0 commit comments

Comments (0)