|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import gc |
6 | | -from typing import TYPE_CHECKING |
| 6 | +from itertools import chain |
| 7 | +from typing import TYPE_CHECKING, ClassVar |
7 | 8 |
|
8 | 9 | from attrs import define |
9 | | -from gpytorch.kernels import Kernel as GPyTorchKernel |
10 | 10 | from typing_extensions import override |
11 | 11 |
|
12 | 12 | from baybe.kernels.base import Kernel |
| 13 | +from baybe.parameters.enum import ParameterKind |
13 | 14 | from baybe.searchspace.core import SearchSpace |
14 | 15 | from baybe.surrogates.gaussian_process.components import LikelihoodFactoryProtocol |
15 | 16 | from baybe.surrogates.gaussian_process.components._gpytorch import ( |
16 | 17 | make_botorch_multitask_likelihood, |
17 | 18 | ) |
18 | 19 | from baybe.surrogates.gaussian_process.components.kernel import ( |
19 | 20 | ICMKernelFactory, |
20 | | - KernelFactoryProtocol, |
| 21 | + _KernelFactory, |
21 | 22 | ) |
22 | 23 | from baybe.surrogates.gaussian_process.components.mean import MeanFactoryProtocol |
23 | 24 |
|
24 | 25 | if TYPE_CHECKING: |
| 26 | + from gpytorch.kernels import Kernel as GPyTorchKernel |
25 | 27 | from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood |
26 | 28 | from gpytorch.means import Mean as GPyTorchMean |
27 | 29 | from torch import Tensor |
28 | 30 |
|
29 | 31 |
|
30 | 32 | @define |
31 | | -class BotorchKernelFactory(KernelFactoryProtocol): |
| 33 | +class BotorchKernelFactory(_KernelFactory): |
32 | 34 | """A factory providing BoTorch kernels.""" |
33 | 35 |
|
| 36 | + _uses_parameter_names: ClassVar[bool] = True |
| 37 | + # See base class. |
| 38 | + |
| 39 | + supported_parameter_kinds: ClassVar[ParameterKind] = ( |
| 40 | + ParameterKind.REGULAR | ParameterKind.TASK |
| 41 | + ) |
| 42 | + # See base class. |
| 43 | + |
34 | 44 | @override |
35 | | - def __call__( |
| 45 | + def _make( |
36 | 46 | self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor |
37 | 47 | ) -> Kernel | GPyTorchKernel: |
38 | 48 | from botorch.models.kernels.positive_index import PositiveIndexKernel |
39 | 49 | from botorch.models.utils.gpytorch_modules import ( |
40 | 50 | get_covar_module_with_dim_scaled_prior, |
41 | 51 | ) |
42 | 52 |
|
43 | | - if searchspace.n_tasks == 1: |
| 53 | + parameter_names = self.get_parameter_names(searchspace) |
| 54 | + |
| 55 | + # Resolve parameter names to active dimension indices |
| 56 | + active_dims: list[int] | None |
| 57 | + if parameter_names is not None: |
| 58 | + active_dims = list( |
| 59 | + chain.from_iterable( |
| 60 | + searchspace.get_comp_rep_parameter_indices(name) |
| 61 | + for name in parameter_names |
| 62 | + ) |
| 63 | + ) |
| 64 | + ard_num_dims = len(active_dims) |
| 65 | + else: |
| 66 | + active_dims = None |
| 67 | + ard_num_dims = len(searchspace.comp_rep_columns) |
| 68 | + |
| 69 | + # Determine if the selected parameters include a task parameter |
| 70 | + task_idx = searchspace.task_idx |
| 71 | + is_multitask = task_idx is not None and ( |
| 72 | + active_dims is None or task_idx in active_dims |
| 73 | + ) |
| 74 | + |
| 75 | + if not is_multitask: |
44 | 76 | return get_covar_module_with_dim_scaled_prior( |
45 | | - ard_num_dims=len(searchspace.comp_rep_columns), active_dims=None |
| 77 | + ard_num_dims=ard_num_dims, active_dims=active_dims |
46 | 78 | ) |
47 | 79 |
|
48 | | - assert searchspace.task_idx is not None |
| 80 | + assert task_idx is not None |
49 | 81 | base_idcs = [ |
50 | 82 | idx |
51 | | - for idx in range(len(searchspace.comp_rep_columns)) |
52 | | - if idx != searchspace.task_idx |
| 83 | + for idx in (active_dims or range(len(searchspace.comp_rep_columns))) |
| 84 | + if idx != task_idx |
53 | 85 | ] |
54 | 86 | base = get_covar_module_with_dim_scaled_prior( |
55 | 87 | ard_num_dims=len(base_idcs), active_dims=base_idcs |
56 | 88 | ) |
57 | 89 | index_kernel = PositiveIndexKernel( |
58 | 90 | num_tasks=searchspace.n_tasks, |
59 | 91 | rank=searchspace.n_tasks, |
60 | | - active_dims=[searchspace.task_idx], |
| 92 | + active_dims=[task_idx], |
61 | 93 | ) |
62 | 94 | return ICMKernelFactory(base, index_kernel)(searchspace, train_x, train_y) |
63 | 95 |
|
|
0 commit comments