Skip to content

Commit bb7b334

Browse files
committed
Docstrings, error messages, variable names, file structure.
1 parent 951af19 commit bb7b334

4 files changed

Lines changed: 63 additions & 61 deletions

File tree

baybe/acquisition/_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def build(self) -> BoAcquisitionFunction:
210210
self._set_ref_point()
211211
self._set_partitioning()
212212
self._set_current_value()
213-
self._set_project()
213+
self._set_projection()
214214
self._set_MFUCB_dicts()
215215

216216
botorch_acqf = self._botorch_acqf_cls(**self._args.collect())
@@ -320,13 +320,12 @@ def _set_current_value(self) -> None:
320320

321321
self._args.current_value = current_value
322322

323-
def _set_project(self) -> None:
323+
def _set_projection(self) -> None:
324324
"""Set projection to the target fidelity for qMFKG."""
325325
if not isinstance(self.acqf, (qMultiFidelityKnowledgeGradient)):
326326
return
327327

328-
# Jordan MHS TODO: check where fidelity--acqf compatibility logic should be.
329-
assert self.searchspace.fidelity_idx is not None, "Unreachable error."
328+
assert self.searchspace.fidelity_idx is not None # for mypy
330329

331330
target_fidelities = {self.searchspace.fidelity_idx: 1.0}
332331

baybe/acquisition/custom_acqfs/two_stage.py renamed to baybe/acquisition/custom_acqfs/mfucb.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
"""BayBE two-stage acquisition functions."""
1+
"""Custom Botorch AnalyticAcquisitionFunction for multi-fidelity optimization."""
22

33
from __future__ import annotations
44

5+
from collections.abc import Callable, Mapping
56
from itertools import pairwise as iter_pairwise
67
from itertools import product as iter_product
7-
from typing import cast
8+
from typing import Any
89

910
import torch
10-
from attrs import define, field
11+
from attrs import Attribute, define, field, fields_dict
1112
from attrs.validators import deep_iterable, deep_mapping, ge, instance_of, or_
1213
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
1314
from botorch.acquisition.objective import PosteriorTransform
@@ -21,12 +22,51 @@
2122
from typing_extensions import override
2223

2324
from baybe.parameters.validation import validate_contains_exactly_one
24-
from baybe.utils.validation import finite_float, validate_dict_shape
25+
from baybe.utils.validation import finite_float
2526

2627
_neg_inv_sqrt2 = -0.7071067811865476
2728
_log_sqrt_pi_div_2 = 0.2257913526447274
2829

2930

31+
def validate_dict_shape(
32+
reference_name: str, /
33+
) -> Callable[[Any, Attribute, Mapping[Any, Any]], None]:
34+
"""Make validator to check attribute keys/lengths against a reference attribute."""
35+
36+
def validator(obj: Any, attribute: Attribute, value: Mapping[Any, Any]) -> None: # noqa: DOC101, DOC103
37+
"""Validate that the input has the same keys/lengths as the reference attribute.
38+
39+
Raises:
40+
ValueError: If the keys of the two attributes mismatch.
41+
ValueError: If the tuple lengths of the two attributes mismatch at any key.
42+
"""
43+
other_attr = fields_dict(type(obj))[reference_name]
44+
other_instance = getattr(obj, reference_name)
45+
46+
if not (
47+
different_keys := set(value.keys()).symmetric_difference(
48+
set(other_instance.keys())
49+
)
50+
):
51+
raise ValueError(
52+
f"{attribute.name} and {other_attr.alias} differ in keys in "
53+
f"{obj.name}, with the following {different_keys} in only one."
54+
)
55+
56+
for k, tup in value.items():
57+
other_tup = other_instance[k]
58+
59+
if len(tup) != len(other_tup):
60+
raise ValueError(
61+
f"The lengths of the attributes '{other_attr.alias}' and "
62+
f"'{attribute.alias}' do not match for '{obj.name}' at the key {k}."
63+
f"Length of '{other_attr.alias}' at key {k}: {len(other_tup)}. "
64+
f"Length of '{attribute.alias}' at key {k}: {len(tup)}."
65+
)
66+
67+
return validator
68+
69+
3070
@define
3171
class MultiFidelityUpperConfidenceBound(AnalyticAcquisitionFunction):
3272
r"""Two-stage Multi Fidelity Upper Confidence Bound (UCB).
@@ -44,7 +84,7 @@ class MultiFidelityUpperConfidenceBound(AnalyticAcquisitionFunction):
4484

4585
# Declaring attribute types for variables defined via _register_buffer.
4686
fidelity_columns: Tensor
47-
fidelities_comb: Tensor
87+
fidelity_combinations: Tensor
4888
zetas_comb: Tensor
4989
costs_comb: Tensor
5090

@@ -128,7 +168,7 @@ def __post_attrs_init__(self) -> None:
128168
)
129169

130170
self.register_buffer(
131-
"fidelities_comb",
171+
"fidelity_combinations",
132172
torch.tensor(
133173
list(iter_product(*self.fidelities.values())), dtype=torch.double
134174
),
@@ -161,10 +201,10 @@ def forward(self, X: Tensor) -> Tensor:
161201
"""
162202
batch_size, q, d = X.shape
163203

164-
n_comb, k = self.fidelities_comb.shape
204+
n_comb, k = self.fidelity_combinations.shape
165205

166206
X_extended = X.clone().unsqueeze(1).repeat(1, n_comb, 1, 1)
167-
X_extended[..., :, self.fidelity_columns] = self.fidelities_comb.view(
207+
X_extended[..., :, self.fidelity_columns] = self.fidelity_combinations.view(
168208
1, n_comb, 1, k
169209
)
170210

@@ -201,20 +241,18 @@ def forward(self, X: Tensor) -> Tensor:
201241

202242
def optimize_stage_two(self, X: Tensor) -> Tensor:
203243
r"""Second optimisation stage: choose optimal fidelity to query."""
204-
# Jordan MHS NOTE: casting here because botorch model likelihood is too
205-
# broadly typed. Check best practice in case likelihood does not have noise.
206-
likelihood = cast(GaussianLikelihood, self.model.likelihood)
207-
208-
# Possible TODO: consider heteroskedastic noise between fidelities.
209-
aleatoric_uncertainty = torch.sqrt(likelihood.noise)
244+
if isinstance(self.model.likelihood, GaussianLikelihood):
245+
aleatoric_uncertainty = torch.sqrt(self.model.likelihood.noise)
246+
else:
247+
aleatoric_uncertainty = torch.tensor(0.0)
210248

211249
found_suitable_lower_fid = False
212250

213251
total_costs_comb = self.costs_comb.sum(dim=-1)
214252
increasing_cost_order = torch.argsort(total_costs_comb)
215253

216254
for prev_i, curr_i in iter_pairwise(increasing_cost_order):
217-
prev_fid = self.fidelities_comb[prev_i].clone()
255+
prev_fid = self.fidelity_combinations[prev_i].clone()
218256
prev_cost = self.costs_comb.sum(dim=-1)[prev_i]
219257
curr_cost = self.costs_comb.sum(dim=-1)[curr_i]
220258
prev_zeta = self.zetas_comb.sum(dim=-1)[prev_i]
@@ -238,7 +276,7 @@ def optimize_stage_two(self, X: Tensor) -> Tensor:
238276

239277
if not found_suitable_lower_fid:
240278
optimal_X = X.clone()
241-
last_fid = self.fidelities_comb[curr_i].clone()
279+
last_fid = self.fidelity_combinations[curr_i].clone()
242280
optimal_X[:, self.fidelity_columns] = last_fid
243281

244282
return optimal_X

baybe/parameters/fidelity.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,20 @@ def highest_fidelity(self) -> str:
110110
value for value, zeta in zip(self.values, self.zeta) if zeta == 0
111111
)
112112

113-
assert isinstance(highest_fid, str), "Error should be unreachable."
113+
assert isinstance(highest_fid, str) # for mypy
114114

115115
return highest_fid
116116

117117
@property
118118
def highest_fidelity_cost(self) -> int:
119119
"""Cost of querying the fidelity with discrepancy value of zero."""
120-
highest_fid = next(
120+
highest_cost = next(
121121
cost for cost, zeta in zip(self.costs, self.zeta) if zeta == 0
122122
)
123123

124-
assert isinstance(highest_fid, int), "Error should be unreachable."
124+
assert isinstance(highest_cost, int) # for mypy
125125

126-
return highest_fid
126+
return highest_cost
127127

128128
@override
129129
@cached_property

baybe/utils/validation.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from __future__ import annotations
44

55
import math
6-
from collections.abc import Callable, Iterable, Mapping
6+
from collections.abc import Callable, Iterable
77
from typing import TYPE_CHECKING, Any
88

99
import numpy as np
1010
import pandas as pd
11-
from attrs import Attribute, fields_dict
11+
from attrs import Attribute
1212

1313
from baybe.exceptions import IncompleteMeasurementsError
1414
from baybe.settings import active_settings
@@ -261,38 +261,3 @@ def preprocess_dataframe(
261261
else:
262262
targets = ()
263263
return normalize_input_dtypes(df, [*searchspace.parameters, *targets])
264-
265-
266-
def validate_dict_shape(
267-
reference_name: str, /
268-
) -> Callable[[Any, Attribute, Mapping[Any, Any]], None]:
269-
"""Make validator to check attribute keys/lengths against a reference attribute."""
270-
271-
def validator(obj: Any, attribute: Attribute, value: Mapping[Any, Any]) -> None: # noqa: DOC101, DOC103
272-
"""Validate that the input has the same keys/lengths as the reference attribute.
273-
274-
Raises:
275-
ValueError: If the keys of the two attributes mismatch.
276-
ValueError: If the tuple lengths of the two attributes mismatch at any key.
277-
"""
278-
other_attr = fields_dict(type(obj))[reference_name]
279-
other_instance = getattr(obj, reference_name)
280-
281-
if set(value.keys()) != set(other_instance.keys()):
282-
raise ValueError(
283-
f"{attribute.name} must have the same keys as {other_attr.alias} in "
284-
f"{obj.name}."
285-
)
286-
287-
for k, tup in value.items():
288-
other_tup = other_instance[k]
289-
290-
if len(tup) != len(other_tup):
291-
raise ValueError(
292-
f"The lengths of the attributes '{other_attr.alias}' and "
293-
f"'{attribute.alias}' do not match for '{obj.name}' at the key {k}."
294-
f"Length of '{other_attr.alias}' at key {k}: {len(other_tup)}. "
295-
f"Length of '{attribute.alias}' at key {k}: {len(tup)}."
296-
)
297-
298-
return validator

0 commit comments

Comments
 (0)