Skip to content

Commit 5258a8f

Browse files
authored
GP Criterion (#789)
DevPR, parent is #745 Makes the optimization criterion of the `GaussianProcessSurrogate` model configurable, in the form of a new `FitCriterion` enum. Potentially, this might be generalized to a class-based approach in the future if more configuration options are required, but for now the simpler solution serves all existing use cases.
2 parents f1e7a0b + f11e111 commit 5258a8f

11 files changed

Lines changed: 226 additions & 49 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- Gaussian process component factories
1616
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
1717
components, enabling full low-level customization
18+
- Configurable fitting criterion for Gaussian process hyperparameter optimization
1819
- Factories for all Gaussian process components
1920
- `CHEN`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
2021
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories

baybe/surrogates/gaussian_process/components/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
"""Gaussian process surrogate components."""
22

3+
from baybe.surrogates.gaussian_process.components.fit_criterion import (
4+
FitCriterion,
5+
FitCriterionFactoryProtocol,
6+
PlainFitCriterionFactory,
7+
)
38
from baybe.surrogates.gaussian_process.components.kernel import (
49
KernelFactoryProtocol,
510
PlainKernelFactory,
@@ -15,6 +20,10 @@
1520
)
1621

1722
__all__ = [
23+
# Fit Criterion
24+
"FitCriterion",
25+
"FitCriterionFactoryProtocol",
26+
"PlainFitCriterionFactory",
1827
# Kernel
1928
"KernelFactoryProtocol",
2029
"PlainKernelFactory",
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Fitting criteria for the Gaussian process surrogate."""
2+
3+
from __future__ import annotations
4+
5+
from enum import Enum
6+
from typing import TYPE_CHECKING
7+
8+
from attrs import define
9+
from typing_extensions import override
10+
11+
if TYPE_CHECKING:
12+
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
13+
from gpytorch.mlls import MarginalLogLikelihood
14+
from gpytorch.models import GP as GPyTorchModel
15+
from torch import Tensor
16+
17+
from baybe.searchspace.core import SearchSpace
18+
19+
20+
class FitCriterion(Enum):
21+
"""Available fitting criteria for GP hyperparameter optimization."""
22+
23+
MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD"
24+
"""Exact marginal log-likelihood."""
25+
26+
LEAVE_ONE_OUT_PSEUDOLIKELIHOOD = "LEAVE_ONE_OUT_PSEUDOLIKELIHOOD"
27+
"""Leave-one-out cross-validation pseudo-likelihood."""
28+
29+
def to_gpytorch(
30+
self, likelihood: GPyTorchLikelihood, model: GPyTorchModel
31+
) -> MarginalLogLikelihood:
32+
"""Create the corresponding GPyTorch MLL object."""
33+
import gpytorch
34+
35+
mll_class = {
36+
FitCriterion.MARGINAL_LOG_LIKELIHOOD: gpytorch.ExactMarginalLogLikelihood,
37+
FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD: gpytorch.mlls.LeaveOneOutPseudoLikelihood, # noqa: E501
38+
}[self]
39+
return mll_class(likelihood, model)
40+
41+
42+
# Delayed import to avoid circular dependency
43+
from baybe.surrogates.gaussian_process.components.generic import ( # noqa: E402
44+
GPComponentFactoryProtocol,
45+
PlainGPComponentFactory,
46+
)
47+
48+
FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion]
49+
"""A protocol defining the interface for fit criterion factories."""
50+
51+
PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
52+
"""A trivial factory that returns a fixed fit criterion."""
53+
54+
55+
@define
56+
class _MLLForNonTLFitCriterionFactory(FitCriterionFactoryProtocol):
57+
"""A fit criterion factory switching between MLL and BayBE default.
58+
59+
In transfer learning contexts, delegates to
60+
:class:`baybe.surrogates.gaussian_process.presets.baybe.BayBEFitCriterionFactory`.
61+
"""
62+
63+
@override
64+
def __call__(
65+
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
66+
) -> FitCriterion:
67+
if searchspace.task_idx is None:
68+
return FitCriterion.MARGINAL_LOG_LIKELIHOOD
69+
70+
from baybe.surrogates.gaussian_process.presets.baybe import (
71+
BayBEFitCriterionFactory,
72+
)
73+
74+
return BayBEFitCriterionFactory()(searchspace, train_x, train_y)

baybe/surrogates/gaussian_process/components/generic.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from baybe.searchspace import SearchSpace
1515
from baybe.serialization.core import block_serialization_hook, converter
1616
from baybe.serialization.mixin import SerialMixin
17+
from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion
1718

18-
BayBEGPComponent: TypeAlias = Kernel
19+
BayBEGPComponent: TypeAlias = Kernel | FitCriterion
1920

2021
if TYPE_CHECKING:
2122
from gpytorch.kernels import Kernel as GPyTorchKernel
@@ -44,15 +45,24 @@ class GPComponentType(Enum):
4445
LIKELIHOOD = "LIKELIHOOD"
4546
"""Gaussian process likelihood."""
4647

48+
CRITERION = "CRITERION"
49+
"""Gaussian process fitting criterion."""
50+
4751
def get_types(self) -> tuple[type, ...]:
4852
"""Get the accepted BayBE and GPyTorch types for this component."""
49-
types = []
53+
types: list[type[GPComponent]] = []
5054

5155
# Add BayBE type if applicable
5256
if self is GPComponentType.KERNEL:
5357
from baybe.kernels.base import Kernel
5458

5559
types.append(Kernel)
60+
elif self is GPComponentType.CRITERION:
61+
from baybe.surrogates.gaussian_process.components.fit_criterion import (
62+
FitCriterion,
63+
)
64+
65+
types.append(FitCriterion)
5666

5767
# Add GPyTorch type if available
5868
if sys.modules.get("gpytorch") is not None:
@@ -85,7 +95,7 @@ def _is_gpytorch_component_class(obj: Any, /) -> bool:
8595

8696
def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None:
8797
"""Validate that an object is a BayBE or a GPyTorch GP component."""
88-
if isinstance(value, Kernel) or _is_gpytorch_component_class(type(value)):
98+
if isinstance(value, BayBEGPComponent) or _is_gpytorch_component_class(type(value)):
8999
return
90100

91101
raise TypeError(

baybe/surrogates/gaussian_process/core.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
from baybe.parameters.categorical import TaskParameter
2020
from baybe.searchspace.core import SearchSpace
2121
from baybe.surrogates.base import Surrogate
22+
from baybe.surrogates.gaussian_process.components.fit_criterion import (
23+
FitCriterion,
24+
FitCriterionFactoryProtocol,
25+
)
2226
from baybe.surrogates.gaussian_process.components.generic import (
2327
GPComponentType,
2428
to_component_factory,
@@ -35,6 +39,7 @@
3539
GaussianProcessPreset,
3640
)
3741
from baybe.surrogates.gaussian_process.presets.baybe import (
42+
BayBEFitCriterionFactory,
3843
BayBEKernelFactory,
3944
BayBELikelihoodFactory,
4045
BayBEMeanFactory,
@@ -178,6 +183,21 @@ class GaussianProcessSurrogate(Surrogate):
178183
* :class:`gpytorch.likelihoods.Likelihood`
179184
"""
180185

186+
criterion_factory: FitCriterionFactoryProtocol = field(
187+
alias="criterion_or_factory",
188+
factory=BayBEFitCriterionFactory,
189+
converter=partial( # type: ignore[misc]
190+
to_component_factory, component_type=GPComponentType.CRITERION
191+
),
192+
validator=is_callable(),
193+
)
194+
"""The fitting criterion for Gaussian process hyperparameter optimization.
195+
196+
Accepts:
197+
* :class:`.components.fit_criterion.FitCriterion`
198+
* :class:`.components.fit_criterion.FitCriterionFactoryProtocol`
199+
"""
200+
181201
# TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently
182202
# omitted due to: https://github.com/python-attrs/cattrs/issues/531
183203
_model = field(init=False, default=None, eq=False)
@@ -195,6 +215,7 @@ def from_preset(
195215
likelihood_or_factory: LikelihoodFactoryProtocol
196216
| GPyTorchLikelihood
197217
| None = None,
218+
criterion_or_factory: FitCriterion | FitCriterionFactoryProtocol | None = None,
198219
) -> Self:
199220
"""Create a Gaussian process surrogate from one of the defined presets."""
200221
preset = GaussianProcessPreset(preset)
@@ -204,13 +225,12 @@ def from_preset(
204225
)
205226
module = importlib.import_module(module_name)
206227

207-
kernel = kernel_or_factory or getattr(module, "PresetKernelFactory")()
208-
mean = mean_or_factory or getattr(module, "PresetMeanFactory")()
209-
likelihood = (
210-
likelihood_or_factory or getattr(module, "PresetLikelihoodFactory")()
211-
)
228+
kernel = kernel_or_factory or getattr(module, "KERNEL_FACTORY")
229+
mean = mean_or_factory or getattr(module, "MEAN_FACTORY")
230+
likelihood = likelihood_or_factory or getattr(module, "LIKELIHOOD_FACTORY")
231+
criterion = criterion_or_factory or getattr(module, "FIT_CRITERION_FACTORY")
212232

213-
return cls(kernel, mean, likelihood)
233+
return cls(kernel, mean, likelihood, criterion)
214234

215235
@override
216236
def to_botorch(self) -> GPyTorchModel:
@@ -237,7 +257,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
237257
@override
238258
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
239259
import botorch
240-
import gpytorch
241260
from botorch.models.transforms import Normalize, Standardize
242261

243262
assert self._searchspace is not None # provided by base class
@@ -281,6 +300,9 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
281300
### Likelihood
282301
likelihood = self.likelihood_factory(context.searchspace, train_x, train_y)
283302

303+
### Criterion
304+
criterion = self.criterion_factory(context.searchspace, train_x, train_y)
305+
284306
### Model construction and fitting
285307
self._model = botorch.models.SingleTaskGP(
286308
train_x,
@@ -291,18 +313,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
291313
covar_module=kernel,
292314
likelihood=likelihood,
293315
)
294-
295-
# TODO: This is still a temporary workaround to avoid overfitting seen in
296-
# low-dimensional TL cases. More robust settings are being researched.
297-
if context.n_task_dimensions > 0:
298-
mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
299-
self._model.likelihood, self._model
300-
)
301-
else:
302-
mll = gpytorch.ExactMarginalLogLikelihood(
303-
self._model.likelihood, self._model
304-
)
305-
316+
mll = criterion.to_gpytorch(self._model.likelihood, self._model)
306317
botorch.fit.fit_gpytorch_mll(mll)
307318

308319
@override
@@ -311,6 +322,9 @@ def __str__(self) -> str:
311322
to_string("Kernel factory", self.kernel_factory, single_line=True),
312323
to_string("Mean factory", self.mean_factory, single_line=True),
313324
to_string("Likelihood factory", self.likelihood_factory, single_line=True),
325+
to_string(
326+
"Fit criterion factory", self.criterion_factory, single_line=True
327+
),
314328
]
315329
return to_string(super().__str__(), *fields)
316330

baybe/surrogates/gaussian_process/presets/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,60 @@
11
"""Gaussian process surrogate presets."""
22

3+
# Criterion
4+
from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion
5+
36
# Default preset
47
from baybe.surrogates.gaussian_process.presets.baybe import (
8+
BayBEFitCriterionFactory,
59
BayBEKernelFactory,
610
BayBELikelihoodFactory,
711
BayBEMeanFactory,
812
)
913

1014
# Chen preset
11-
from baybe.surrogates.gaussian_process.presets.chen import CHENKernelFactory
15+
from baybe.surrogates.gaussian_process.presets.chen import (
16+
CHENFitCriterionFactory,
17+
CHENKernelFactory,
18+
)
1219

1320
# Core
1421
from baybe.surrogates.gaussian_process.presets.core import GaussianProcessPreset
1522

1623
# EDBO preset
1724
from baybe.surrogates.gaussian_process.presets.edbo import (
25+
EDBOFitCriterionFactory,
1826
EDBOKernelFactory,
1927
EDBOLikelihoodFactory,
2028
EDBOMeanFactory,
2129
)
2230

2331
# Smoothed EDBO preset
2432
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
33+
SmoothedEDBOFitCriterionFactory,
2534
SmoothedEDBOKernelFactory,
2635
SmoothedEDBOLikelihoodFactory,
2736
SmoothedEDBOMeanFactory,
2837
)
2938

3039
__all__ = [
3140
# Core
41+
"FitCriterion",
3242
"GaussianProcessPreset",
3343
# Default BayBE preset
44+
"BayBEFitCriterionFactory",
3445
"BayBEKernelFactory",
3546
"BayBELikelihoodFactory",
3647
"BayBEMeanFactory",
3748
# Chen preset
49+
"CHENFitCriterionFactory",
3850
"CHENKernelFactory",
3951
# EDBO preset
52+
"EDBOFitCriterionFactory",
4053
"EDBOKernelFactory",
4154
"EDBOLikelihoodFactory",
4255
"EDBOMeanFactory",
4356
# Smoothed EDBO preset
57+
"SmoothedEDBOFitCriterionFactory",
4458
"SmoothedEDBOKernelFactory",
4559
"SmoothedEDBOLikelihoodFactory",
4660
"SmoothedEDBOMeanFactory",

baybe/surrogates/gaussian_process/presets/baybe.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
to_parameter_selector,
1818
)
1919
from baybe.searchspace.core import SearchSpace
20+
from baybe.surrogates.gaussian_process.components.fit_criterion import (
21+
FitCriterion,
22+
FitCriterionFactoryProtocol,
23+
)
2024
from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory
2125
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory
2226
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
@@ -85,7 +89,24 @@ def _make(
8589
BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory
8690
"""The factory providing the default likelihood for Gaussian process surrogates."""
8791

88-
# Aliases for generic preset imports
89-
PresetKernelFactory = BayBEKernelFactory
90-
PresetMeanFactory = BayBEMeanFactory
91-
PresetLikelihoodFactory = BayBELikelihoodFactory
92+
93+
@define
94+
class BayBEFitCriterionFactory(FitCriterionFactoryProtocol):
95+
"""The factory providing the default fitting criterion for Gaussian process surrogates.""" # noqa: E501
96+
97+
@override
98+
def __call__(
99+
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
100+
) -> FitCriterion:
101+
return (
102+
FitCriterion.MARGINAL_LOG_LIKELIHOOD
103+
if searchspace.task_idx is None
104+
else FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD
105+
)
106+
107+
108+
# Preset defaults
109+
KERNEL_FACTORY = BayBEKernelFactory()
110+
MEAN_FACTORY = BayBEMeanFactory()
111+
LIKELIHOOD_FACTORY = BayBELikelihoodFactory()
112+
FIT_CRITERION_FACTORY = BayBEFitCriterionFactory()

baybe/surrogates/gaussian_process/presets/chen.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
to_parameter_selector,
1919
)
2020
from baybe.priors.basic import GammaPrior
21+
from baybe.surrogates.gaussian_process.components.fit_criterion import (
22+
_MLLForNonTLFitCriterionFactory,
23+
)
2124
from baybe.surrogates.gaussian_process.components.kernel import (
2225
_PureKernelFactory,
2326
)
@@ -68,10 +71,14 @@ def _make(
6871
)
6972

7073

74+
CHENFitCriterionFactory = _MLLForNonTLFitCriterionFactory()
75+
"""A factory providing fitting criteria for the CHEN preset."""
76+
7177
# Collect leftover original slotted classes processed by `attrs.define`
7278
gc.collect()
7379

74-
# Aliases for generic preset imports
75-
PresetKernelFactory = CHENKernelFactory
76-
PresetMeanFactory = LazyConstantMeanFactory
77-
PresetLikelihoodFactory = LazyGaussianLikelihoodFactory
80+
# Preset defaults
81+
KERNEL_FACTORY = CHENKernelFactory()
82+
MEAN_FACTORY = LazyConstantMeanFactory()
83+
LIKELIHOOD_FACTORY = LazyGaussianLikelihoodFactory()
84+
FIT_CRITERION_FACTORY = CHENFitCriterionFactory

0 commit comments

Comments
 (0)