Skip to content

Commit 7ff993d

Browse files
committed
Rename Criterion to FitCriterion
1 parent 79f3604 commit 7ff993d

11 files changed

Lines changed: 74 additions & 68 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- Gaussian process component factories
1616
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
1717
components, enabling full low-level customization
18-
- Configurable optimization criterion for Gaussian process hyperparameter selection
18+
- Configurable fitting criterion for Gaussian process hyperparameter optimization
1919
- Factories for all Gaussian process components
2020
- `CHEN`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
2121
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories

baybe/surrogates/gaussian_process/components/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Gaussian process surrogate components."""
22

33
from baybe.surrogates.gaussian_process.components.criterion import (
4-
Criterion,
5-
CriterionFactoryProtocol,
6-
PlainCriterionFactory,
4+
FitCriterion,
5+
FitCriterionFactoryProtocol,
6+
PlainFitCriterionFactory,
77
)
88
from baybe.surrogates.gaussian_process.components.kernel import (
99
KernelFactoryProtocol,
@@ -21,9 +21,9 @@
2121

2222
__all__ = [
2323
# Criterion
24-
"Criterion",
25-
"CriterionFactoryProtocol",
26-
"PlainCriterionFactory",
24+
"FitCriterion",
25+
"FitCriterionFactoryProtocol",
26+
"PlainFitCriterionFactory",
2727
# Kernel
2828
"KernelFactoryProtocol",
2929
"PlainKernelFactory",

baybe/surrogates/gaussian_process/components/criterion.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Optimization criteria for the Gaussian process surrogate."""
1+
"""Fitting criteria for the Gaussian process surrogate."""
22

33
from __future__ import annotations
44

@@ -16,8 +16,8 @@
1616
from gpytorch.models import GP as GPyTorchModel
1717

1818

19-
class Criterion(Enum):
20-
"""Available optimization criteria for GP hyperparameter selection."""
19+
class FitCriterion(Enum):
20+
"""Available fitting criteria for GP hyperparameter optimization."""
2121

2222
MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD"
2323
"""Exact marginal log-likelihood."""
@@ -32,14 +32,14 @@ def to_gpytorch(
3232
import gpytorch
3333

3434
mll_class = {
35-
Criterion.MARGINAL_LOG_LIKELIHOOD: gpytorch.ExactMarginalLogLikelihood,
36-
Criterion.LEAVE_ONE_OUT: gpytorch.mlls.LeaveOneOutPseudoLikelihood,
35+
FitCriterion.MARGINAL_LOG_LIKELIHOOD: gpytorch.ExactMarginalLogLikelihood,
36+
FitCriterion.LEAVE_ONE_OUT: gpytorch.mlls.LeaveOneOutPseudoLikelihood,
3737
}[self]
3838
return mll_class(likelihood, model)
3939

4040

41-
CriterionFactoryProtocol = GPComponentFactoryProtocol[Criterion]
42-
"""A protocol defining the interface for criterion factories."""
41+
FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion]
42+
"""A protocol defining the interface for fit criterion factories."""
4343

44-
PlainCriterionFactory = PlainGPComponentFactory[Criterion]
45-
"""A trivial factory that returns a fixed criterion."""
44+
PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
45+
"""A trivial factory that returns a fixed fit criterion."""

baybe/surrogates/gaussian_process/components/generic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
from gpytorch.means import Mean as GPyTorchMean
2424
from torch import Tensor
2525

26-
from baybe.surrogates.gaussian_process.components.criterion import Criterion
26+
from baybe.surrogates.gaussian_process.components.criterion import FitCriterion
2727

2828
GPyTorchGPComponent: TypeAlias = GPyTorchKernel | GPyTorchMean | GPyTorchLikelihood
29-
GPComponent: TypeAlias = BayBEGPComponent | GPyTorchGPComponent | Criterion
29+
GPComponent: TypeAlias = BayBEGPComponent | GPyTorchGPComponent | FitCriterion
3030
else:
3131
# At runtime, we use only the BayBE types for serialization compatibility
3232
GPComponent: TypeAlias = BayBEGPComponent
@@ -47,7 +47,7 @@ class GPComponentType(Enum):
4747
"""Gaussian process likelihood."""
4848

4949
CRITERION = "CRITERION"
50-
"""Gaussian process optimization criterion."""
50+
"""Gaussian process fitting criterion."""
5151

5252
def get_types(self) -> tuple[type, ...]:
5353
"""Get the accepted BayBE and GPyTorch types for this component."""
@@ -60,10 +60,10 @@ def get_types(self) -> tuple[type, ...]:
6060
types.append(Kernel)
6161
elif self is GPComponentType.CRITERION:
6262
from baybe.surrogates.gaussian_process.components.criterion import (
63-
Criterion,
63+
FitCriterion,
6464
)
6565

66-
types.append(Criterion)
66+
types.append(FitCriterion)
6767

6868
# Add GPyTorch type if available
6969
if sys.modules.get("gpytorch") is not None:

baybe/surrogates/gaussian_process/core.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from baybe.searchspace.core import SearchSpace
2121
from baybe.surrogates.base import Surrogate
2222
from baybe.surrogates.gaussian_process.components.criterion import (
23-
Criterion,
24-
CriterionFactoryProtocol,
23+
FitCriterion,
24+
FitCriterionFactoryProtocol,
2525
)
2626
from baybe.surrogates.gaussian_process.components.generic import (
2727
GPComponentType,
@@ -39,7 +39,7 @@
3939
GaussianProcessPreset,
4040
)
4141
from baybe.surrogates.gaussian_process.presets.baybe import (
42-
BayBECriterionFactory,
42+
BayBEFitCriterionFactory,
4343
BayBEKernelFactory,
4444
BayBELikelihoodFactory,
4545
BayBEMeanFactory,
@@ -183,19 +183,19 @@ class GaussianProcessSurrogate(Surrogate):
183183
* :class:`gpytorch.likelihoods.Likelihood`
184184
"""
185185

186-
criterion_factory: CriterionFactoryProtocol = field(
186+
criterion_factory: FitCriterionFactoryProtocol = field(
187187
alias="criterion_or_factory",
188-
factory=BayBECriterionFactory,
188+
factory=BayBEFitCriterionFactory,
189189
converter=partial( # type: ignore[misc]
190190
to_component_factory, component_type=GPComponentType.CRITERION
191191
),
192192
validator=is_callable(),
193193
)
194-
"""The optimization criterion for hyperparameter selection.
194+
"""The fitting criterion for Gaussian process hyperparameter optimization.
195195
196196
Accepts:
197-
* :class:`.components.criterion.Criterion`
198-
* :class:`.components.criterion.CriterionFactoryProtocol`
197+
* :class:`.components.criterion.FitCriterion`
198+
* :class:`.components.criterion.FitCriterionFactoryProtocol`
199199
"""
200200

201201
# TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently
@@ -215,7 +215,7 @@ def from_preset(
215215
likelihood_or_factory: LikelihoodFactoryProtocol
216216
| GPyTorchLikelihood
217217
| None = None,
218-
criterion_or_factory: Criterion | CriterionFactoryProtocol | None = None,
218+
criterion_or_factory: FitCriterion | FitCriterionFactoryProtocol | None = None,
219219
) -> Self:
220220
"""Create a Gaussian process surrogate from one of the defined presets."""
221221
preset = GaussianProcessPreset(preset)
@@ -230,7 +230,9 @@ def from_preset(
230230
likelihood = likelihood_or_factory or getattr(
231231
module, "PRESET_LIKELIHOOD_FACTORY"
232232
)
233-
criterion = criterion_or_factory or getattr(module, "PRESET_CRITERION_FACTORY")
233+
criterion = criterion_or_factory or getattr(
234+
module, "PRESET_FIT_CRITERION_FACTORY"
235+
)
234236

235237
return cls(kernel, mean, likelihood, criterion)
236238

@@ -324,7 +326,9 @@ def __str__(self) -> str:
324326
to_string("Kernel factory", self.kernel_factory, single_line=True),
325327
to_string("Mean factory", self.mean_factory, single_line=True),
326328
to_string("Likelihood factory", self.likelihood_factory, single_line=True),
327-
to_string("Criterion factory", self.criterion_factory, single_line=True),
329+
to_string(
330+
"Fit criterion factory", self.criterion_factory, single_line=True
331+
),
328332
]
329333
return to_string(super().__str__(), *fields)
330334

baybe/surrogates/gaussian_process/presets/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
"""Gaussian process surrogate presets."""
22

33
# Criterion
4-
from baybe.surrogates.gaussian_process.components.criterion import Criterion
4+
from baybe.surrogates.gaussian_process.components.criterion import FitCriterion
55

66
# Default preset
77
from baybe.surrogates.gaussian_process.presets.baybe import (
8-
BayBECriterionFactory,
8+
BayBEFitCriterionFactory,
99
BayBEKernelFactory,
1010
BayBELikelihoodFactory,
1111
BayBEMeanFactory,
1212
)
1313

1414
# Chen preset
1515
from baybe.surrogates.gaussian_process.presets.chen import (
16-
CHENCriterionFactory,
16+
CHENFitCriterionFactory,
1717
CHENKernelFactory,
1818
)
1919

@@ -22,39 +22,39 @@
2222

2323
# EDBO preset
2424
from baybe.surrogates.gaussian_process.presets.edbo import (
25-
EDBOCriterionFactory,
25+
EDBOFitCriterionFactory,
2626
EDBOKernelFactory,
2727
EDBOLikelihoodFactory,
2828
EDBOMeanFactory,
2929
)
3030

3131
# Smoothed EDBO preset
3232
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
33-
SmoothedEDBOCriterionFactory,
33+
SmoothedEDBOFitCriterionFactory,
3434
SmoothedEDBOKernelFactory,
3535
SmoothedEDBOLikelihoodFactory,
3636
SmoothedEDBOMeanFactory,
3737
)
3838

3939
__all__ = [
4040
# Core
41-
"Criterion",
41+
"FitCriterion",
4242
"GaussianProcessPreset",
4343
# Default BayBE preset
44-
"BayBECriterionFactory",
44+
"BayBEFitCriterionFactory",
4545
"BayBEKernelFactory",
4646
"BayBELikelihoodFactory",
4747
"BayBEMeanFactory",
4848
# Chen preset
49-
"CHENCriterionFactory",
49+
"CHENFitCriterionFactory",
5050
"CHENKernelFactory",
5151
# EDBO preset
52-
"EDBOCriterionFactory",
52+
"EDBOFitCriterionFactory",
5353
"EDBOKernelFactory",
5454
"EDBOLikelihoodFactory",
5555
"EDBOMeanFactory",
5656
# Smoothed EDBO preset
57-
"SmoothedEDBOCriterionFactory",
57+
"SmoothedEDBOFitCriterionFactory",
5858
"SmoothedEDBOKernelFactory",
5959
"SmoothedEDBOLikelihoodFactory",
6060
"SmoothedEDBOMeanFactory",

baybe/surrogates/gaussian_process/presets/baybe.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
)
1919
from baybe.searchspace.core import SearchSpace
2020
from baybe.surrogates.gaussian_process.components.criterion import (
21-
Criterion,
22-
CriterionFactoryProtocol,
21+
FitCriterion,
22+
FitCriterionFactoryProtocol,
2323
)
2424
from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory
2525
from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory
@@ -91,22 +91,22 @@ def _make(
9191

9292

9393
@define
94-
class BayBECriterionFactory(CriterionFactoryProtocol):
95-
"""The factory providing the default optimization criterion for Gaussian process surrogates.""" # noqa: E501
94+
class BayBEFitCriterionFactory(FitCriterionFactoryProtocol):
95+
"""The factory providing the default fitting criterion for Gaussian process surrogates.""" # noqa: E501
9696

9797
@override
9898
def __call__(
9999
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
100-
) -> Criterion:
100+
) -> FitCriterion:
101101
return (
102-
Criterion.MARGINAL_LOG_LIKELIHOOD
102+
FitCriterion.MARGINAL_LOG_LIKELIHOOD
103103
if searchspace.task_idx is None
104-
else Criterion.LEAVE_ONE_OUT
104+
else FitCriterion.LEAVE_ONE_OUT
105105
)
106106

107107

108108
# Preset defaults
109109
PRESET_KERNEL_FACTORY = BayBEKernelFactory()
110110
PRESET_MEAN_FACTORY = BayBEMeanFactory()
111111
PRESET_LIKELIHOOD_FACTORY = BayBELikelihoodFactory()
112-
PRESET_CRITERION_FACTORY = BayBECriterionFactory()
112+
PRESET_FIT_CRITERION_FACTORY = BayBEFitCriterionFactory()

baybe/surrogates/gaussian_process/presets/chen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
)
2020
from baybe.priors.basic import GammaPrior
2121
from baybe.surrogates.gaussian_process.components.criterion import (
22-
Criterion,
23-
PlainCriterionFactory,
22+
FitCriterion,
23+
PlainFitCriterionFactory,
2424
)
2525
from baybe.surrogates.gaussian_process.components.kernel import (
2626
_PureKernelFactory,
@@ -72,8 +72,8 @@ def _make(
7272
)
7373

7474

75-
CHENCriterionFactory = PlainCriterionFactory(Criterion.MARGINAL_LOG_LIKELIHOOD)
76-
"""A factory providing optimization criteria for the CHEN preset."""
75+
CHENFitCriterionFactory = PlainFitCriterionFactory(FitCriterion.MARGINAL_LOG_LIKELIHOOD)
76+
"""A factory providing fitting criteria for the CHEN preset."""
7777

7878
# Collect leftover original slotted classes processed by `attrs.define`
7979
gc.collect()
@@ -82,4 +82,4 @@ def _make(
8282
PRESET_KERNEL_FACTORY = CHENKernelFactory()
8383
PRESET_MEAN_FACTORY = LazyConstantMeanFactory()
8484
PRESET_LIKELIHOOD_FACTORY = LazyGaussianLikelihoodFactory()
85-
PRESET_CRITERION_FACTORY = CHENCriterionFactory
85+
PRESET_FIT_CRITERION_FACTORY = CHENFitCriterionFactory

baybe/surrogates/gaussian_process/presets/edbo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from baybe.priors.basic import GammaPrior
2323
from baybe.searchspace.discrete import SubspaceDiscrete
2424
from baybe.surrogates.gaussian_process.components.criterion import (
25-
Criterion,
26-
PlainCriterionFactory,
25+
FitCriterion,
26+
PlainFitCriterionFactory,
2727
)
2828
from baybe.surrogates.gaussian_process.components.kernel import (
2929
_PureKernelFactory,
@@ -180,13 +180,13 @@ def __call__(
180180

181181

182182
# Collect leftover original slotted classes processed by `attrs.define`
183-
EDBOCriterionFactory = PlainCriterionFactory(Criterion.MARGINAL_LOG_LIKELIHOOD)
184-
"""A factory providing optimization criteria for the EDBO preset."""
183+
EDBOFitCriterionFactory = PlainFitCriterionFactory(FitCriterion.MARGINAL_LOG_LIKELIHOOD)
184+
"""A factory providing fitting criteria for the EDBO preset."""
185185

186186
gc.collect()
187187

188188
# Preset defaults
189189
PRESET_KERNEL_FACTORY = EDBOKernelFactory()
190190
PRESET_MEAN_FACTORY = EDBOMeanFactory()
191191
PRESET_LIKELIHOOD_FACTORY = EDBOLikelihoodFactory()
192-
PRESET_CRITERION_FACTORY = EDBOCriterionFactory
192+
PRESET_FIT_CRITERION_FACTORY = EDBOFitCriterionFactory

baybe/surrogates/gaussian_process/presets/edbo_smoothed.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
)
2020
from baybe.priors.basic import GammaPrior
2121
from baybe.surrogates.gaussian_process.components.criterion import (
22-
Criterion,
23-
PlainCriterionFactory,
22+
FitCriterion,
23+
PlainFitCriterionFactory,
2424
)
2525
from baybe.surrogates.gaussian_process.components.kernel import (
2626
_PureKernelFactory,
@@ -130,8 +130,10 @@ def __call__(
130130
return likelihood
131131

132132

133-
SmoothedEDBOCriterionFactory = PlainCriterionFactory(Criterion.MARGINAL_LOG_LIKELIHOOD)
134-
"""A factory providing optimization criteria for the smoothed EDBO preset."""
133+
SmoothedEDBOFitCriterionFactory = PlainFitCriterionFactory(
134+
FitCriterion.MARGINAL_LOG_LIKELIHOOD
135+
)
136+
"""A factory providing fitting criteria for the smoothed EDBO preset."""
135137

136138
# Collect leftover original slotted classes processed by `attrs.define`
137139
gc.collect()
@@ -140,4 +142,4 @@ def __call__(
140142
PRESET_KERNEL_FACTORY = SmoothedEDBOKernelFactory()
141143
PRESET_MEAN_FACTORY = SmoothedEDBOMeanFactory()
142144
PRESET_LIKELIHOOD_FACTORY = SmoothedEDBOLikelihoodFactory()
143-
PRESET_CRITERION_FACTORY = SmoothedEDBOCriterionFactory
145+
PRESET_FIT_CRITERION_FACTORY = SmoothedEDBOFitCriterionFactory

0 commit comments

Comments
 (0)