|
32 | 32 | _construct_specified_input_transforms, |
33 | 33 | _extract_model_kwargs, |
34 | 34 | _make_botorch_input_transform, |
| 35 | + _submodel_input_constructor_mtgp, |
35 | 36 | submodel_input_constructor, |
36 | 37 | Surrogate, |
37 | 38 | SurrogateSpec, |
|
59 | 60 | from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks. |
60 | 61 | from botorch.models.multitask import MultiTaskGP |
61 | 62 | from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood |
62 | | -from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize |
| 63 | +from botorch.models.transforms.input import ( |
| 64 | + ChainedInputTransform, |
| 65 | + LearnedFeatureImputation, |
| 66 | + Log10, |
| 67 | + Normalize, |
| 68 | +) |
63 | 69 | from botorch.models.transforms.outcome import OutcomeTransform, Standardize |
64 | 70 | from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset |
65 | 71 | from botorch.utils.evaluation import compute_in_sample_model_fit_metric |
@@ -277,6 +283,72 @@ def test__make_botorch_input_transform(self) -> None: |
277 | 283 | self.assertEqual(transform.indices.tolist(), [0]) |
278 | 284 | self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]]) |
279 | 285 |
|
| 286 | + def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None: |
| 287 | + """_submodel_input_constructor_mtgp passes map_heterogeneous_to_full |
| 288 | + to construct_inputs when LFI is configured, enabling zero-padded |
| 289 | + heterogeneous datasets to be used with MultiTaskGP.""" |
| 290 | + ds_target = SupervisedDataset( |
| 291 | + X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]), |
| 292 | + Y=torch.tensor([[1.0], [2.0]]), |
| 293 | + feature_names=["x0", "task"], |
| 294 | + outcome_names=["y_task_0"], |
| 295 | + ) |
| 296 | + ds_source = SupervisedDataset( |
| 297 | + X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]), |
| 298 | + Y=torch.tensor([[3.0], [4.0]]), |
| 299 | + feature_names=["x0", "x1", "task"], |
| 300 | + outcome_names=["y_task_1"], |
| 301 | + ) |
| 302 | + mt_dataset = MultiTaskDataset( |
| 303 | + datasets=[ds_target, ds_source], |
| 304 | + target_outcome_name="y_task_0", |
| 305 | + task_feature_index=-1, |
| 306 | + ) |
| 307 | + self.assertTrue(mt_dataset.has_heterogeneous_features) |
| 308 | + ssd = SearchSpaceDigest( |
| 309 | + feature_names=["x0", "x1", "task"], |
| 310 | + bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)], |
| 311 | + task_features=[2], |
| 312 | + target_values={2: 0.0}, |
| 313 | + ) |
| 314 | + surrogate = Surrogate( |
| 315 | + surrogate_spec=SurrogateSpec( |
| 316 | + model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)] |
| 317 | + ) |
| 318 | + ) |
| 319 | + |
| 320 | + with self.subTest("with LFI — construct_inputs succeeds"): |
| 321 | + config_with_lfi = ModelConfig( |
| 322 | + botorch_model_class=MultiTaskGP, |
| 323 | + input_transform_classes=[Normalize, LearnedFeatureImputation], |
| 324 | + ) |
| 325 | + result = _submodel_input_constructor_mtgp( |
| 326 | + botorch_model_class=MultiTaskGP, |
| 327 | + model_config=config_with_lfi, |
| 328 | + dataset=mt_dataset, |
| 329 | + search_space_digest=ssd, |
| 330 | + surrogate=surrogate, |
| 331 | + ) |
| 332 | + self.assertEqual(result["train_X"].shape[-1], 3) |
| 333 | + |
| 334 | + with self.subTest("without LFI — construct_inputs raises"): |
| 335 | + from botorch.exceptions.errors import ( |
| 336 | + UnsupportedError as BotorchUnsupportedError, |
| 337 | + ) |
| 338 | + |
| 339 | + config_no_lfi = ModelConfig( |
| 340 | + botorch_model_class=MultiTaskGP, |
| 341 | + input_transform_classes=[Normalize], |
| 342 | + ) |
| 343 | + with self.assertRaises(BotorchUnsupportedError): |
| 344 | + _submodel_input_constructor_mtgp( |
| 345 | + botorch_model_class=MultiTaskGP, |
| 346 | + model_config=config_no_lfi, |
| 347 | + dataset=mt_dataset, |
| 348 | + search_space_digest=ssd, |
| 349 | + surrogate=surrogate, |
| 350 | + ) |
| 351 | + |
280 | 352 |
|
281 | 353 | class SurrogateTest(TestCase): |
282 | 354 | def setUp(self, cuda: bool = False) -> None: |
|
0 commit comments