1- """BayBE two-stage acquisition functions ."""
1+ """Custom Botorch AnalyticAcquisitionFunction for multi-fidelity optimization ."""
22
33from __future__ import annotations
44
5+ from collections .abc import Callable , Mapping
56from itertools import pairwise as iter_pairwise
67from itertools import product as iter_product
7- from typing import cast
8+ from typing import Any
89
910import torch
10- from attrs import define , field
11+ from attrs import Attribute , define , field , fields_dict
1112from attrs .validators import deep_iterable , deep_mapping , ge , instance_of , or_
1213from botorch .acquisition .analytic import AnalyticAcquisitionFunction
1314from botorch .acquisition .objective import PosteriorTransform
2122from typing_extensions import override
2223
2324from baybe .parameters .validation import validate_contains_exactly_one
24- from baybe .utils .validation import finite_float , validate_dict_shape
25+ from baybe .utils .validation import finite_float
2526
2627_neg_inv_sqrt2 = - 0.7071067811865476
2728_log_sqrt_pi_div_2 = 0.2257913526447274
2829
2930
31+ def validate_dict_shape (
32+ reference_name : str , /
33+ ) -> Callable [[Any , Attribute , Mapping [Any , Any ]], None ]:
34+ """Make validator to check attribute keys/lengths against a reference attribute."""
35+
36+ def validator (obj : Any , attribute : Attribute , value : Mapping [Any , Any ]) -> None : # noqa: DOC101, DOC103
37+ """Validate that the input has the same keys/lengths as the reference attribute.
38+
39+ Raises:
40+ ValueError: If the keys of the two attributes mismatch.
41+ ValueError: If the tuple lengths of the two attributes mismatch at any key.
42+ """
43+ other_attr = fields_dict (type (obj ))[reference_name ]
44+ other_instance = getattr (obj , reference_name )
45+
46+ if not (
47+ different_keys := set (value .keys ()).symmetric_difference (
48+ set (other_instance .keys ())
49+ )
50+ ):
51+ raise ValueError (
52+ f"{ attribute .name } and { other_attr .alias } differ in keys in "
53+ f"{ obj .name } , with the following { different_keys } in only one."
54+ )
55+
56+ for k , tup in value .items ():
57+ other_tup = other_instance [k ]
58+
59+ if len (tup ) != len (other_tup ):
60+ raise ValueError (
61+ f"The lengths of the attributes '{ other_attr .alias } ' and "
62+ f"'{ attribute .alias } ' do not match for '{ obj .name } ' at the key { k } ."
63+ f"Length of '{ other_attr .alias } ' at key { k } : { len (other_tup )} . "
64+ f"Length of '{ attribute .alias } ' at key { k } : { len (tup )} ."
65+ )
66+
67+ return validator
68+
69+
3070@define
3171class MultiFidelityUpperConfidenceBound (AnalyticAcquisitionFunction ):
3272 r"""Two-stage Multi Fidelity Upper Confidence Bound (UCB).
@@ -44,7 +84,7 @@ class MultiFidelityUpperConfidenceBound(AnalyticAcquisitionFunction):
4484
4585 # Declaring attribute types for variables defined via _register_buffer.
4686 fidelity_columns : Tensor
47- fidelities_comb : Tensor
87+ fidelity_combinations : Tensor
4888 zetas_comb : Tensor
4989 costs_comb : Tensor
5090
@@ -128,7 +168,7 @@ def __post_attrs_init__(self) -> None:
128168 )
129169
130170 self .register_buffer (
131- "fidelities_comb " ,
171+ "fidelity_combinations " ,
132172 torch .tensor (
133173 list (iter_product (* self .fidelities .values ())), dtype = torch .double
134174 ),
@@ -161,10 +201,10 @@ def forward(self, X: Tensor) -> Tensor:
161201 """
162202 batch_size , q , d = X .shape
163203
164- n_comb , k = self .fidelities_comb .shape
204+ n_comb , k = self .fidelity_combinations .shape
165205
166206 X_extended = X .clone ().unsqueeze (1 ).repeat (1 , n_comb , 1 , 1 )
167- X_extended [..., :, self .fidelity_columns ] = self .fidelities_comb .view (
207+ X_extended [..., :, self .fidelity_columns ] = self .fidelity_combinations .view (
168208 1 , n_comb , 1 , k
169209 )
170210
@@ -201,20 +241,18 @@ def forward(self, X: Tensor) -> Tensor:
201241
202242 def optimize_stage_two (self , X : Tensor ) -> Tensor :
203243 r"""Second optimisation stage: choose optimal fidelity to query."""
204- # Jordan MHS NOTE: casting here because botorch model likelihood is too
205- # broadly typed. Check best practice in case likelihood does not have noise.
206- likelihood = cast (GaussianLikelihood , self .model .likelihood )
207-
208- # Possible TODO: consider heteroskedastic noise between fidelities.
209- aleatoric_uncertainty = torch .sqrt (likelihood .noise )
244+ if isinstance (self .model .likelihood , GaussianLikelihood ):
245+ aleatoric_uncertainty = torch .sqrt (self .model .likelihood .noise )
246+ else :
247+ aleatoric_uncertainty = torch .tensor (0.0 )
210248
211249 found_suitable_lower_fid = False
212250
213251 total_costs_comb = self .costs_comb .sum (dim = - 1 )
214252 increasing_cost_order = torch .argsort (total_costs_comb )
215253
216254 for prev_i , curr_i in iter_pairwise (increasing_cost_order ):
217- prev_fid = self .fidelities_comb [prev_i ].clone ()
255+ prev_fid = self .fidelity_combinations [prev_i ].clone ()
218256 prev_cost = self .costs_comb .sum (dim = - 1 )[prev_i ]
219257 curr_cost = self .costs_comb .sum (dim = - 1 )[curr_i ]
220258 prev_zeta = self .zetas_comb .sum (dim = - 1 )[prev_i ]
@@ -238,7 +276,7 @@ def optimize_stage_two(self, X: Tensor) -> Tensor:
238276
239277 if not found_suitable_lower_fid :
240278 optimal_X = X .clone ()
241- last_fid = self .fidelities_comb [curr_i ].clone ()
279+ last_fid = self .fidelity_combinations [curr_i ].clone ()
242280 optimal_X [:, self .fidelity_columns ] = last_fid
243281
244282 return optimal_X
0 commit comments