Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions baybe/acquisition/_builder.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to do the following:

I don't know when you branched this branch off of main, but we only recently added the AGENTS.md files (#769) which auto-inject instructions to achieve consistent code using agentic development. So if you haven't done that yet, please rebase this PR on main immediately and ask the agent to "replay" the commits with the new rules from AGENTS.md files in mind, the force push

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_ExpectedHypervolumeImprovement,
qExpectedHypervolumeImprovement,
qLogExpectedHypervolumeImprovement,
qLogNoisyExpectedImprovement,
qNegIntegratedPosteriorVariance,
qThompsonSampling,
)
Expand Down Expand Up @@ -75,6 +76,7 @@ class BotorchAcquisitionArgs:
# Optional, depending on the specific acquisition function being used
best_f: float | None = _OPT_FIELD
beta: float | None = _OPT_FIELD
constraints: list | None = _OPT_FIELD
maximize: bool | None = _OPT_FIELD
mc_points: Tensor | None = _OPT_FIELD
num_fantasies: int | None = _OPT_FIELD
Expand Down Expand Up @@ -197,6 +199,7 @@ def build(self) -> BoAcquisitionFunction:
# Set context-specific parameters
self._set_best_f()
self._set_target_transformation()
self._set_constraints()
self._set_X_baseline()
self._set_X_pending()
self._set_mc_points()
Expand All @@ -222,6 +225,18 @@ def _set_target_transformation(self) -> None:
return

if self.acqf.is_analytic:
# TODO: Certain analytic acquisition functions (e.g. analytic EI with
# constraints) do support outcome constraints and will be added to BayBE
# in the future. Once available, this guard should be scoped to only
# those analytic acqfs that do NOT support constraints, and
# `to_botorch_posterior_transform()` must be fixed to pad the weight
# vector to length `n_models`.
if self.objective.outcome_constraints:
raise IncompatibilityError(
f"Analytical acquisition function '{type(self.acqf).__name__}' "
f"does not support outcome constraints. Use an MC-based "
f"acquisition function instead."
)
try:
transform = self.objective.to_botorch_posterior_transform()
except NonGaussianityError as ex:
Expand Down Expand Up @@ -253,17 +268,111 @@ def _set_target_transformation(self) -> None:

self._args.objective = self.objective.to_botorch()

def _set_constraints(self) -> None:
"""Set BoTorch's ``constraints`` argument from outcome constraints.

Outcome constraint compatibility check — Layer 2 (acquisition function level).
Raises IncompatibilityError if the acqf's BoTorch __init__ signature does not
include a ``constraints`` parameter.
"""
if not self.objective.outcome_constraints:
return

if flds.constraints.name not in self._signature:
raise IncompatibilityError(
f"The selected acquisition function "
f"'{type(self.acqf).__name__}' does not support outcome "
f"constraints. Use a compatible acquisition function such as "
f"'{qLogNoisyExpectedImprovement.__name__}' instead."
)
constraints = self.objective.to_botorch_constraints()
if constraints:
self._args.constraints = constraints

def _set_best_f(self) -> None:
"""Set BoTorch's ``best_f`` argument."""
"""Set BoTorch's ``best_f`` argument.

best_f is a constant reference value (not differentiable). When outcome
constraints are present, only feasible training points are considered.
"""
if flds.best_f.name not in self._signature:
return

match self.objective:
case SingleTargetObjective() | DesirabilityObjective():
self._args.best_f = self._posterior_mean_comp.max().item()
if not (constraints := self.objective.to_botorch_constraints()):
self._args.best_f = self._posterior_mean_comp.max().item()
else:
self._args.best_f = self._compute_best_f_with_constraints(
constraints
)
case _:
raise NotImplementedError("This line should be impossible to reach.")

def _compute_best_f_with_constraints(
self, constraints: list[Callable[[Tensor], Tensor]]
) -> float:
"""Compute the best objective value considering outcome constraints.

Falls back to the global maximum if no feasible training point exists.

Args:
constraints: Constraint functions from
:meth:`~baybe.objectives.base.Objective.to_botorch_constraints`.

Returns:
The best feasible objective value, or the global maximum as fallback.
"""
# Get objective values for all training points
objective_values = self._posterior_mean_comp

# Get raw model predictions for constraint evaluation
batched = to_tensor(self._train_x).unsqueeze(-2)
posterior = self._botorch_surrogate.posterior(batched)
model_predictions = posterior.mean.squeeze(-2)

# Apply constraint functions to filter feasible points
feasible_mask = self._compute_feasible_mask(model_predictions, constraints)

if not feasible_mask.any():
# TODO: other mechanisms, e.g. steer towards feasible region?
# No feasible training points - fall back to global maximum
return objective_values.max().item()

# Return maximum among feasible points
feasible_objectives = objective_values[feasible_mask]
return feasible_objectives.max().item()

def _compute_feasible_mask(
self,
model_predictions: Tensor,
constraints: list[Callable[[Tensor], Tensor]],
) -> Tensor:
"""Compute boolean mask indicating which points satisfy all constraints.

Uses hard thresholding (feasible when constraint value <= 0) combined
via boolean AND across all constraints.

Args:
model_predictions: Raw model predictions [n_points, n_outputs]
constraints: Constraint functions from to_botorch_constraints()

Returns:
Boolean mask [n_points] where True = feasible, False = infeasible
"""
n_points = model_predictions.shape[0]
feasible_mask = torch.ones(n_points, dtype=torch.bool)

for constraint_func in constraints:
# Constraint func: [batch, q, m] -> [batch, q]; we insert q=1
# via unsqueeze(-2), so output is [n_points, 1]; squeeze q dim.
constraint_violations = constraint_func(
model_predictions.unsqueeze(-2)
).squeeze(-1)
feasible_mask &= constraint_violations <= 0

return feasible_mask
Comment on lines +363 to +374

def set_default_sample_shape(self, acqf: BoAcquisitionFunction, /):
"""Apply temporary workaround for Thompson sampling."""
# TODO: Needs redesign once bandits are supported more generally
Expand Down
3 changes: 3 additions & 0 deletions baybe/constraints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DiscreteProductConstraint,
DiscreteSumConstraint,
)
from baybe.constraints.outcome import OutcomeConstraint
from baybe.constraints.validation import validate_constraints

__all__ = [
Expand All @@ -42,6 +43,8 @@
"DiscretePermutationInvarianceConstraint",
"DiscreteProductConstraint",
"DiscreteSumConstraint",
# --- Outcome constraints ---#
"OutcomeConstraint",
# --- Other --- #
"validate_constraints",
"DISCRETE_CONSTRAINTS_FILTERING_ORDER",
Expand Down
89 changes: 89 additions & 0 deletions baybe/constraints/outcome.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Functionality for outcome constraints."""

from __future__ import annotations

from collections.abc import Callable
from typing import Literal

import pandas as pd
import torch
from attrs import define, field
from attrs.validators import in_, instance_of

from baybe.serialization.mixin import SerialMixin
from baybe.targets.base import Target


@define(frozen=True, slots=False)
class OutcomeConstraint(SerialMixin):
"""A constraint applied to target outcomes in the output space.

Outcome constraints restrict the feasible region based on target predictions,
different from parameter constraints which restrict the input space.
"""

target: Target = field(validator=instance_of(Target))
"""The target to be constrained."""

operator: Literal["<=", ">=", "=="] = field(validator=in_(["<=", ">=", "=="]))
"""The constraint operator."""

threshold: float = field(validator=instance_of((int, float)), converter=float)
"""The constraint threshold value in experimental units."""

Comment on lines +25 to +33
def __str__(self) -> str:
"""Return string representation."""
return f"{self.target.name} {self.operator} {self.threshold}"

def get_computational_threshold(self) -> float:
"""Convert experimental threshold to computational units.

Returns:
The threshold value in computational units.
"""
# Create dummy series with threshold value in experimental units
experimental_series = pd.Series([self.threshold], name=self.target.name)

# Apply the same transformations as the target
computational_series = self.target.transform(experimental_series)

return computational_series.iloc[0]

def to_botorch_constraint_func(
self, target_idx: int
) -> Callable[[torch.Tensor], torch.Tensor]:
"""Create a botorch-compatible constraint function.

Args:
target_idx: Index of the target in model output.

Returns:
A constraint function that returns <= 0 for feasible region.
"""
computational_threshold = self.get_computational_threshold()

def constraint_func(samples: torch.Tensor) -> torch.Tensor:
"""Constraint function operating on computational level.

Args:
samples: Model output samples in computational units.

Returns:
Constraint values where <= 0 indicates feasible region.

Raises:
ValueError: If the constraint operator is not supported.
"""
if self.operator == "<=":
return samples[..., target_idx] - computational_threshold
elif self.operator == ">=":
return computational_threshold - samples[..., target_idx]
elif self.operator == "==":
# Equality constraint with small tolerance
return (
torch.abs(samples[..., target_idx] - computational_threshold) - 1e-6
)
else:
raise ValueError(f"Unsupported constraint operator: {self.operator}")

return constraint_func
4 changes: 4 additions & 0 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class MinimumCardinalityViolatedWarning(UserWarning):
"""Minimum cardinality constraints are violated."""


class OutcomeConstraintIgnoredWarning(UserWarning):
"""Outcome constraints are present but cannot be enforced by the recommender."""


##### Exceptions #####


Expand Down
Loading
Loading