1919from baybe .parameters .categorical import TaskParameter
2020from baybe .searchspace .core import SearchSpace
2121from baybe .surrogates .base import Surrogate
22+ from baybe .surrogates .gaussian_process .components .fit_criterion import (
23+ FitCriterion ,
24+ FitCriterionFactoryProtocol ,
25+ )
2226from baybe .surrogates .gaussian_process .components .generic import (
2327 GPComponentType ,
2428 to_component_factory ,
3539 GaussianProcessPreset ,
3640)
3741from baybe .surrogates .gaussian_process .presets .baybe import (
42+ BayBEFitCriterionFactory ,
3843 BayBEKernelFactory ,
3944 BayBELikelihoodFactory ,
4045 BayBEMeanFactory ,
@@ -178,6 +183,21 @@ class GaussianProcessSurrogate(Surrogate):
178183 * :class:`gpytorch.likelihoods.Likelihood`
179184 """
180185
186+ criterion_factory : FitCriterionFactoryProtocol = field (
187+ alias = "criterion_or_factory" ,
188+ factory = BayBEFitCriterionFactory ,
189+ converter = partial ( # type: ignore[misc]
190+ to_component_factory , component_type = GPComponentType .CRITERION
191+ ),
192+ validator = is_callable (),
193+ )
194+ """The fitting criterion for Gaussian process hyperparameter optimization.
195+
196+ Accepts:
197+ * :class:`.components.fit_criterion.FitCriterion`
198+ * :class:`.components.fit_criterion.FitCriterionFactoryProtocol`
199+ """
200+
181201 # TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently
182202 # omitted due to: https://github.com/python-attrs/cattrs/issues/531
183203 _model = field (init = False , default = None , eq = False )
@@ -195,6 +215,7 @@ def from_preset(
195215 likelihood_or_factory : LikelihoodFactoryProtocol
196216 | GPyTorchLikelihood
197217 | None = None ,
218+ criterion_or_factory : FitCriterion | FitCriterionFactoryProtocol | None = None ,
198219 ) -> Self :
199220 """Create a Gaussian process surrogate from one of the defined presets."""
200221 preset = GaussianProcessPreset (preset )
@@ -204,13 +225,12 @@ def from_preset(
204225 )
205226 module = importlib .import_module (module_name )
206227
207- kernel = kernel_or_factory or getattr (module , "PresetKernelFactory" )()
208- mean = mean_or_factory or getattr (module , "PresetMeanFactory" )()
209- likelihood = (
210- likelihood_or_factory or getattr (module , "PresetLikelihoodFactory" )()
211- )
228+ kernel = kernel_or_factory or getattr (module , "KERNEL_FACTORY" )
229+ mean = mean_or_factory or getattr (module , "MEAN_FACTORY" )
230+ likelihood = likelihood_or_factory or getattr (module , "LIKELIHOOD_FACTORY" )
231+ criterion = criterion_or_factory or getattr (module , "FIT_CRITERION_FACTORY" )
212232
213- return cls (kernel , mean , likelihood )
233+ return cls (kernel , mean , likelihood , criterion )
214234
215235 @override
216236 def to_botorch (self ) -> GPyTorchModel :
@@ -237,7 +257,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
237257 @override
238258 def _fit (self , train_x : Tensor , train_y : Tensor ) -> None :
239259 import botorch
240- import gpytorch
241260 from botorch .models .transforms import Normalize , Standardize
242261
243262 assert self ._searchspace is not None # provided by base class
@@ -281,6 +300,9 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
281300 ### Likelihood
282301 likelihood = self .likelihood_factory (context .searchspace , train_x , train_y )
283302
303+ ### Criterion
304+ criterion = self .criterion_factory (context .searchspace , train_x , train_y )
305+
284306 ### Model construction and fitting
285307 self ._model = botorch .models .SingleTaskGP (
286308 train_x ,
@@ -291,18 +313,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
291313 covar_module = kernel ,
292314 likelihood = likelihood ,
293315 )
294-
295- # TODO: This is still a temporary workaround to avoid overfitting seen in
296- # low-dimensional TL cases. More robust settings are being researched.
297- if context .n_task_dimensions > 0 :
298- mll = gpytorch .mlls .LeaveOneOutPseudoLikelihood (
299- self ._model .likelihood , self ._model
300- )
301- else :
302- mll = gpytorch .ExactMarginalLogLikelihood (
303- self ._model .likelihood , self ._model
304- )
305-
316+ mll = criterion .to_gpytorch (self ._model .likelihood , self ._model )
306317 botorch .fit .fit_gpytorch_mll (mll )
307318
308319 @override
@@ -311,6 +322,9 @@ def __str__(self) -> str:
311322 to_string ("Kernel factory" , self .kernel_factory , single_line = True ),
312323 to_string ("Mean factory" , self .mean_factory , single_line = True ),
313324 to_string ("Likelihood factory" , self .likelihood_factory , single_line = True ),
325+ to_string (
326+ "Fit criterion factory" , self .criterion_factory , single_line = True
327+ ),
314328 ]
315329 return to_string (super ().__str__ (), * fields )
316330
0 commit comments