Skip to content

Commit d55fd74

Browse files
committed
Fall back to BayBE fit criterion for unspecified cases
1 parent 2c49cc8 commit d55fd74

4 files changed

Lines changed: 35 additions & 12 deletions

File tree

baybe/surrogates/gaussian_process/components/fit_criterion.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
from enum import Enum
66
from typing import TYPE_CHECKING
77

8+
from attrs import define
9+
from typing_extensions import override
10+
811
if TYPE_CHECKING:
912
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
1013
from gpytorch.mlls import MarginalLogLikelihood
1114
from gpytorch.models import GP as GPyTorchModel
15+
from torch import Tensor
16+
17+
from baybe.searchspace.core import SearchSpace
1218

1319

1420
class FitCriterion(Enum):
@@ -44,3 +50,25 @@ def to_gpytorch(
4450

4551
PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
4652
"""A trivial factory that returns a fixed fit criterion."""
53+
54+
55+
@define
56+
class _MLLForNonTLFitCriterionFactory(FitCriterionFactoryProtocol):
57+
"""A fit criterion factory switching between MLL and BayBE depending on the context.
58+
59+
In transfer learning contexts, delegates to
60+
:class:`baybe.surrogates.gaussian_process.presets.baybe.BayBEFitCriterionFactory`.
61+
"""
62+
63+
@override
64+
def __call__(
65+
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
66+
) -> FitCriterion:
67+
if searchspace.task_idx is None:
68+
return FitCriterion.MARGINAL_LOG_LIKELIHOOD
69+
70+
from baybe.surrogates.gaussian_process.presets.baybe import (
71+
BayBEFitCriterionFactory,
72+
)
73+
74+
return BayBEFitCriterionFactory()(searchspace, train_x, train_y)

baybe/surrogates/gaussian_process/presets/chen.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
)
2020
from baybe.priors.basic import GammaPrior
2121
from baybe.surrogates.gaussian_process.components.fit_criterion import (
22-
FitCriterion,
23-
PlainFitCriterionFactory,
22+
_MLLForNonTLFitCriterionFactory,
2423
)
2524
from baybe.surrogates.gaussian_process.components.kernel import (
2625
_PureKernelFactory,
@@ -72,7 +71,7 @@ def _make(
7271
)
7372

7473

75-
CHENFitCriterionFactory = PlainFitCriterionFactory(FitCriterion.MARGINAL_LOG_LIKELIHOOD)
74+
CHENFitCriterionFactory = _MLLForNonTLFitCriterionFactory()
7675
"""A factory providing fitting criteria for the CHEN preset."""
7776

7877
# Collect leftover original slotted classes processed by `attrs.define`

baybe/surrogates/gaussian_process/presets/edbo.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
from baybe.priors.basic import GammaPrior
2323
from baybe.searchspace.discrete import SubspaceDiscrete
2424
from baybe.surrogates.gaussian_process.components.fit_criterion import (
25-
FitCriterion,
26-
PlainFitCriterionFactory,
25+
_MLLForNonTLFitCriterionFactory,
2726
)
2827
from baybe.surrogates.gaussian_process.components.kernel import (
2928
_PureKernelFactory,
@@ -179,10 +178,10 @@ def __call__(
179178
return likelihood
180179

181180

182-
# Collect leftover original slotted classes processed by `attrs.define`
183-
EDBOFitCriterionFactory = PlainFitCriterionFactory(FitCriterion.MARGINAL_LOG_LIKELIHOOD)
181+
EDBOFitCriterionFactory = _MLLForNonTLFitCriterionFactory()
184182
"""A factory providing fitting criteria for the EDBO preset."""
185183

184+
# Collect leftover original slotted classes processed by `attrs.define`
186185
gc.collect()
187186

188187
# Preset defaults

baybe/surrogates/gaussian_process/presets/edbo_smoothed.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
)
2020
from baybe.priors.basic import GammaPrior
2121
from baybe.surrogates.gaussian_process.components.fit_criterion import (
22-
FitCriterion,
23-
PlainFitCriterionFactory,
22+
_MLLForNonTLFitCriterionFactory,
2423
)
2524
from baybe.surrogates.gaussian_process.components.kernel import (
2625
_PureKernelFactory,
@@ -130,9 +129,7 @@ def __call__(
130129
return likelihood
131130

132131

133-
SmoothedEDBOFitCriterionFactory = PlainFitCriterionFactory(
134-
FitCriterion.MARGINAL_LOG_LIKELIHOOD
135-
)
132+
SmoothedEDBOFitCriterionFactory = _MLLForNonTLFitCriterionFactory()
136133
"""A factory providing fitting criteria for the smoothed EDBO preset."""
137134

138135
# Collect leftover original slotted classes processed by `attrs.define`

0 commit comments

Comments
 (0)