Skip to content

Commit 7e8ebf0

Browse files
Carl Hvarfner and facebook-github-bot
authored and committed
Default to MultiTaskGP + LearnedFeatureImputation for heterogeneous TL (facebook#5193)
Summary: Switches the default heterogeneous transfer learning model from a specialized per-task kernel model to a standard multi-task GP with learned feature imputation. The previous default model class is marked as deprecated. Differential Revision: D102197137
1 parent faf78a1 commit 7e8ebf0

3 files changed

Lines changed: 19 additions & 12 deletions

File tree

ax/generators/torch/botorch_modular/input_constructors/input_transforms.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -402,10 +402,14 @@ def _input_transform_argparse_learned_feature_imputation(
402402
torch.ones(d, dtype=dtype, device=torch_device),
403403
]
404404
)
405+
# The target task is at position 0 (target_dataset is prepended above), so
406+
# at posterior time — when X arrives without a task column — LFI applies
407+
# the target task's imputation pattern.
405408
kwargs: dict[str, Any] = {
406409
"feature_indices": feature_indices,
407410
"d": d,
408411
"task_feature_index": task_feature_index,
412+
"target_task": 0,
409413
"bounds": bounds,
410414
"device": torch_device,
411415
"dtype": dtype,

ax/generators/torch/botorch_modular/utils.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -319,16 +319,15 @@ def choose_model_class(
319319
)
320320

321321
# Check for heterogeneous multi-task datasets. If a model class was
322-
# explicitly specified, respect it; otherwise default to HeterogeneousMTGP.
322+
# explicitly specified, respect it; otherwise default to MultiTaskGP
323+
# (LearnedFeatureImputation handles missing features).
323324
if (
324325
search_space_digest.task_features
325326
and isinstance(dataset, MultiTaskDataset)
326327
and dataset.has_heterogeneous_features
327328
):
328329
model_class = (
329-
specified_model_class
330-
if specified_model_class is not None
331-
else HeterogeneousMTGP
330+
specified_model_class if specified_model_class is not None else MultiTaskGP
332331
)
333332
logger.debug(f"Chose BoTorch model class: {model_class}.")
334333
return model_class

ax/generators/torch/tests/test_utils.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -186,9 +186,9 @@ def test_choose_model_class_heterogeneous_task_features(self) -> None:
186186
mt_dataset = self._get_heterogeneous_mt_dataset()
187187
ssd = dataclasses.replace(self.search_space_digest, task_features=[-1])
188188

189-
# Default: HeterogeneousMTGP.
189+
# Default: MultiTaskGP (LearnedFeatureImputation handles missing features).
190190
self.assertEqual(
191-
HeterogeneousMTGP,
191+
MultiTaskGP,
192192
choose_model_class(dataset=mt_dataset, search_space_digest=ssd),
193193
)
194194

@@ -233,19 +233,23 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None:
233233
mt_dataset = self._get_heterogeneous_mt_dataset()
234234
ssd = dataclasses.replace(self.search_space_digest, task_features=[-1])
235235

236-
# Default (no model class specified) -> HeterogeneousMTGP.
237-
# LFI is NOT injected; input_transform_classes stays DEFAULT.
236+
# Default (no model class specified) -> MultiTaskGP.
237+
# LFI is injected for MultiTaskGP with heterogeneous data.
238238
updated_config = copy_model_config_with_default_values(
239239
model_config=ModelConfig(),
240240
dataset=mt_dataset,
241241
search_space_digest=ssd,
242242
)
243-
self.assertEqual(updated_config.botorch_model_class, HeterogeneousMTGP)
244-
self.assertEqual(updated_config.input_transform_classes, [Normalize])
243+
self.assertEqual(updated_config.botorch_model_class, MultiTaskGP)
245244
self.assertEqual(
246-
none_throws(updated_config.input_transform_options),
247-
{"Normalize": {}},
245+
updated_config.input_transform_classes,
246+
[Normalize, LearnedFeatureImputation],
248247
)
248+
# LFI is present in transform classes but absent from options; its
249+
# argparse computes kwargs from the dataset at construction time.
250+
ito = none_throws(updated_config.input_transform_options)
251+
self.assertEqual(ito, {"Normalize": {}})
252+
self.assertNotIn("LearnedFeatureImputation", ito)
249253

250254
# Explicit HeterogeneousMTGP is still respected when specified.
251255
updated_config = copy_model_config_with_default_values(

0 commit comments

Comments
 (0)