Skip to content

Commit 2f8be3a

Browse files
committed
Make BotorchKernelFactory support parameter selection
1 parent 3c5ec16 commit 2f8be3a

1 file changed

Lines changed: 43 additions & 11 deletions

File tree

  • baybe/surrogates/gaussian_process/presets

baybe/surrogates/gaussian_process/presets/botorch.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,61 +3,93 @@
33
from __future__ import annotations
44

55
import gc
6-
from typing import TYPE_CHECKING
6+
from itertools import chain
7+
from typing import TYPE_CHECKING, ClassVar
78

89
from attrs import define
9-
from gpytorch.kernels import Kernel as GPyTorchKernel
1010
from typing_extensions import override
1111

1212
from baybe.kernels.base import Kernel
13+
from baybe.parameters.enum import ParameterKind
1314
from baybe.searchspace.core import SearchSpace
1415
from baybe.surrogates.gaussian_process.components import LikelihoodFactoryProtocol
1516
from baybe.surrogates.gaussian_process.components._gpytorch import (
1617
make_botorch_multitask_likelihood,
1718
)
1819
from baybe.surrogates.gaussian_process.components.kernel import (
1920
ICMKernelFactory,
20-
KernelFactoryProtocol,
21+
_KernelFactory,
2122
)
2223
from baybe.surrogates.gaussian_process.components.mean import MeanFactoryProtocol
2324

2425
if TYPE_CHECKING:
26+
from gpytorch.kernels import Kernel as GPyTorchKernel
2527
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
2628
from gpytorch.means import Mean as GPyTorchMean
2729
from torch import Tensor
2830

2931

3032
@define
31-
class BotorchKernelFactory(KernelFactoryProtocol):
33+
class BotorchKernelFactory(_KernelFactory):
3234
"""A factory providing BoTorch kernels."""
3335

36+
_uses_parameter_names: ClassVar[bool] = True
37+
# See base class.
38+
39+
supported_parameter_kinds: ClassVar[ParameterKind] = (
40+
ParameterKind.REGULAR | ParameterKind.TASK
41+
)
42+
# See base class.
43+
3444
@override
35-
def __call__(
45+
def _make(
3646
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
3747
) -> Kernel | GPyTorchKernel:
3848
from botorch.models.kernels.positive_index import PositiveIndexKernel
3949
from botorch.models.utils.gpytorch_modules import (
4050
get_covar_module_with_dim_scaled_prior,
4151
)
4252

43-
if searchspace.n_tasks == 1:
53+
parameter_names = self.get_parameter_names(searchspace)
54+
55+
# Resolve parameter names to active dimension indices
56+
active_dims: list[int] | None
57+
if parameter_names is not None:
58+
active_dims = list(
59+
chain.from_iterable(
60+
searchspace.get_comp_rep_parameter_indices(name)
61+
for name in parameter_names
62+
)
63+
)
64+
ard_num_dims = len(active_dims)
65+
else:
66+
active_dims = None
67+
ard_num_dims = len(searchspace.comp_rep_columns)
68+
69+
# Determine if the selected parameters include a task parameter
70+
task_idx = searchspace.task_idx
71+
is_multitask = task_idx is not None and (
72+
active_dims is None or task_idx in active_dims
73+
)
74+
75+
if not is_multitask:
4476
return get_covar_module_with_dim_scaled_prior(
45-
ard_num_dims=len(searchspace.comp_rep_columns), active_dims=None
77+
ard_num_dims=ard_num_dims, active_dims=active_dims
4678
)
4779

48-
assert searchspace.task_idx is not None
80+
assert task_idx is not None
4981
base_idcs = [
5082
idx
51-
for idx in range(len(searchspace.comp_rep_columns))
52-
if idx != searchspace.task_idx
83+
for idx in (active_dims or range(len(searchspace.comp_rep_columns)))
84+
if idx != task_idx
5385
]
5486
base = get_covar_module_with_dim_scaled_prior(
5587
ard_num_dims=len(base_idcs), active_dims=base_idcs
5688
)
5789
index_kernel = PositiveIndexKernel(
5890
num_tasks=searchspace.n_tasks,
5991
rank=searchspace.n_tasks,
60-
active_dims=[searchspace.task_idx],
92+
active_dims=[task_idx],
6193
)
6294
return ICMKernelFactory(base, index_kernel)(searchspace, train_x, train_y)
6395

0 commit comments

Comments
 (0)