Skip to content

Commit c40a33f

Browse files
Lena Kashtelyanmeta-codesync[bot]
authored andcommitted
Rename EarlyStoppingStrategy → ArmStoppingStrategy with arm-level decisions
Summary: # Context The current `EarlyStoppingStrategy` was built for single-arm `Trial`s only — it returns `dict[int, str | None]` (trial index → reason) and explicitly rejects `BatchTrial`s. This diff extends it to support `BatchTrial`s by making stopping decisions at the arm level. But we are going to need batch-level stopping. # Key design decision: `ArmStoppingStrategy` or `TrialStoppingStrategy`? [in Ax TLs sync, we decided we liked `ArmStoppingStrategy` One choice we're going to need to make is whether we want arm-level and trial-level stopping strategies to be the same or separate. The use cases for arm stopping will have to do with safety, constraint violations etc. The use cases for stopping trials will have more to do with normal orchestration, e.g. we need to stop a trial in order to run a new one. # Next step: add `Runner.stop_arm` and `Runner.stop_trial` I think that the two will be separate bc they'll often entail different logic at their respective backends. How we choose to do these may impact how we choose to do the stopping strategy, too. # Another likely next step: make GS and ESS decisions jointly I think that the two will be separate bc they'll often entail different logic at their respective backends. How we choose to do these may impact how we choose to do the stopping strategy, too. I know not everyone loves this idea; let's discuss. Related discussion about a use case here: https://docs.google.com/document/d/19K3kBXX9c5WUIUu_t_gC9KZkeAMo_EffSymKjaPTSyI/edit?tab=t.0, re-pasting for convenience: - Lena [on whether an experiment that does not yet do any generation and only stopping]: My preferred design would be that we use a GNode to fit the model in GS, then the ESS uses this, but at the moment ESS is applied first, then GS (in Orchestrator and thus Axolotl). I think this current order is right as long as we don't merge ESS and GS (which I'd like to do eventually, including for reasons like this). So what we can do for now is just put a GNode within an ESS, then worry about the rest later. And we can just have an empty GS for now to keep the Orchestrator happy. - Sam: Yeah we could do that. Calling ESS before GS in the orchestrator fundamentally seems like the wrong order as we move toward model-based early stopping (e.g., in the conductor case). I wonder if we should merge GS and ESS sooner rather than later to resolve that, rather than create some tech debt by having ESS fit its own adapter. Eventually we decided that we would like to (later this year) merge ESS and GS, such that the actual cycle is `gen` --> cache results --> compare ROI on new vs. running trial --> make decision on stopping and running jointly. We thought some kind of `DecisionNode` might do this: {F1987649989} ---- # Claude stuff below Key changes: - Rename `BaseEarlyStoppingStrategy` → `BaseArmStoppingStrategy` (with backward-compat alias) - Rename `ModelBasedEarlyStoppingStrategy` → `ModelBasedArmStoppingStrategy` (with alias) - Change return type of `should_stop_trials_early` / `_should_stop_trials_early` from `dict[int, str | None]` → `dict[int, dict[str, str | None]]` (trial_index → {arm_name → reason}) - Remove `BatchTrial` rejection check in `is_eligible_any` - Add `_wrap_trial_results_with_arms()` helper to convert trial-level decisions to arm-level format - Update all subclasses (percentile, threshold, logical, multi_objective, quickbo) to use the new return type - Update orchestrator to check if all arms are stopped before stopping a trial (raises `NotImplementedError` for partial arm stopping) - Update `ax_client`, `api/client`, and `internal_client` to extract reasons from arm-level dict - Update all tests Differential Revision: D97304068
1 parent 4ebc621 commit c40a33f

16 files changed

Lines changed: 266 additions & 220 deletions

File tree

ax/api/client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -657,17 +657,17 @@ def should_stop_trial_early(self, trial_index: int) -> bool:
657657

658658
es_response = none_throws(
659659
self._early_stopping_strategy_or_choose()
660-
).should_stop_trials_early(
660+
).should_stop_arms(
661661
trial_indices={trial_index},
662662
experiment=self._experiment,
663663
current_node=self._generation_strategy_or_choose()._curr,
664664
)
665665

666666
if trial_index in es_response:
667-
logger.info(
668-
f"Trial {trial_index} should be stopped early: "
669-
f"{es_response[trial_index]}"
670-
)
667+
# Extract reason from arm-level decisions (use first arm's reason)
668+
arm_decisions = es_response[trial_index]
669+
reason = next(iter(arm_decisions.values())) if arm_decisions else None
670+
logger.info(f"Trial {trial_index} should be stopped early: {reason}")
671671
return True
672672

673673
return False

ax/early_stopping/strategies/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
# pyre-strict
88

99
from ax.early_stopping.strategies.base import (
10+
BaseArmStoppingStrategy,
1011
BaseEarlyStoppingStrategy,
12+
ModelBasedArmStoppingStrategy,
1113
ModelBasedEarlyStoppingStrategy,
14+
TArmsToStop,
1215
)
1316
from ax.early_stopping.strategies.logical import (
1417
AndEarlyStoppingStrategy,
@@ -20,8 +23,11 @@
2023

2124

2225
__all__ = [
26+
"BaseArmStoppingStrategy",
2327
"BaseEarlyStoppingStrategy",
28+
"ModelBasedArmStoppingStrategy",
2429
"ModelBasedEarlyStoppingStrategy",
30+
"TArmsToStop",
2531
"PercentileEarlyStoppingStrategy",
2632
"ThresholdEarlyStoppingStrategy",
2733
"AndEarlyStoppingStrategy",

ax/early_stopping/strategies/base.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import pandas as pd
1616
from ax.adapter.data_utils import _maybe_normalize_map_key
17-
from ax.core.batch_trial import BatchTrial
1817
from ax.core.data import Data, MAP_KEY
1918
from ax.core.experiment import Experiment
2019
from ax.core.trial_status import TrialStatus
@@ -36,10 +35,19 @@
3635
# backwards compatibility when loading old strategies.
3736
REMOVED_EARLY_STOPPING_STRATEGY_KWARGS: set[str] = {"trial_indices_to_ignore"}
3837

38+
# Type alias for arm-level stopping decisions:
39+
# trial_index -> {arm_name -> optional_reason}
40+
TArmsToStop = dict[int, dict[str, str | None]]
3941

40-
class BaseEarlyStoppingStrategy(ABC, Base):
42+
43+
class BaseArmStoppingStrategy(ABC, Base):
4144
"""Interface for heuristics that halt trials early, typically based on early
42-
results from that trial."""
45+
results from that trial.
46+
47+
Stopping decisions are made at the arm level: the return type of
48+
``should_stop_arms`` is ``dict[int, dict[str, str | None]]``
49+
mapping ``trial_index -> {arm_name -> optional_reason}``. For single-arm
50+
``Trial`` objects this dict will contain exactly one entry per trial."""
4351

4452
def __init__(
4553
self,
@@ -120,12 +128,12 @@ def __init__(
120128
self._last_check_progressions: dict[int, float] = {}
121129

122130
@abstractmethod
123-
def _should_stop_trials_early(
131+
def _should_stop_arms(
124132
self,
125133
trial_indices: set[int],
126134
experiment: Experiment,
127135
current_node: GenerationNode | None = None,
128-
) -> dict[int, str | None]:
136+
) -> TArmsToStop:
129137
"""Decide whether to complete trials before evaluation is fully concluded.
130138
131139
Typical examples include stopping a machine learning model's training, or
@@ -140,8 +148,9 @@ def _should_stop_trials_early(
140148
stopping decisions.
141149
142150
Returns:
143-
A dictionary mapping trial indices that should be early stopped to
144-
(optional) messages with the associated reason.
151+
A dictionary mapping trial indices to arm-level stopping decisions.
152+
Each value is a dict mapping arm names to (optional) reason strings
153+
for arms that should be stopped.
145154
"""
146155
pass
147156

@@ -163,12 +172,12 @@ def _is_harmful(
163172
"""
164173
pass
165174

166-
def should_stop_trials_early(
175+
def should_stop_arms(
167176
self,
168177
trial_indices: set[int],
169178
experiment: Experiment,
170179
current_node: GenerationNode | None = None,
171-
) -> dict[int, str | None]:
180+
) -> TArmsToStop:
172181
"""Decide whether trials should be stopped before evaluation is fully concluded.
173182
This method identifies trials that should be stopped based on early signals that
174183
are indicative of final performance. Early stopping is not applied if doing so
@@ -183,17 +192,17 @@ def should_stop_trials_early(
183192
stopping decisions.
184193
185194
Returns:
186-
A dictionary mapping trial indices that should be early stopped to
187-
(optional) messages with the associated reason. Returns an empty
188-
dictionary if early stopping would be harmful (when safety check is
189-
enabled).
195+
A dictionary mapping trial indices to arm-level stopping decisions.
196+
Each value is a dict mapping arm names to (optional) reason strings
197+
for arms that should be stopped. Returns an empty dictionary if
198+
early stopping would be harmful (when safety check is enabled).
190199
"""
191200
if self.check_safe and self._is_harmful(
192201
trial_indices=trial_indices,
193202
experiment=experiment,
194203
):
195204
return {}
196-
return self._should_stop_trials_early(
205+
return self._should_stop_arms(
197206
trial_indices=trial_indices,
198207
experiment=experiment,
199208
current_node=current_node,
@@ -340,17 +349,6 @@ def is_eligible_any(
340349
then we can skip costly steps, such as model fitting, that occur before
341350
individual trials are considered for stopping.
342351
"""
343-
# check for batch trials
344-
for idx, trial in experiment.trials.items():
345-
if isinstance(trial, BatchTrial):
346-
# In particular, align_partial_results requires a 1-1 mapping between
347-
# trial indices and arm names, which is not the case for batch trials.
348-
# See align_partial_results for more details.
349-
raise ValueError(
350-
f"Trial {idx} is a BatchTrial, which is not yet supported by "
351-
"early stopping strategies."
352-
)
353-
354352
# check that there are sufficient completed trials
355353
num_completed = len(experiment.trial_indices_by_status[TrialStatus.COMPLETED])
356354
if self.min_curves is not None and num_completed < self.min_curves:
@@ -585,7 +583,7 @@ def _prepare_aligned_data(
585583
return long_df, multilevel_wide_df
586584

587585

588-
class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
586+
class ModelBasedArmStoppingStrategy(BaseArmStoppingStrategy):
589587
"""A base class for model based early stopping strategies. Includes
590588
a helper function for processing Data into arrays."""
591589

@@ -666,3 +664,8 @@ def _lookup_and_validate_data(
666664
full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
667665
map_data = Data(df=full_df)
668666
return map_data
667+
668+
669+
# Deprecated aliases for backward compatibility.
670+
BaseEarlyStoppingStrategy = BaseArmStoppingStrategy
671+
ModelBasedEarlyStoppingStrategy = ModelBasedArmStoppingStrategy

ax/early_stopping/strategies/logical.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from functools import reduce
1010

1111
from ax.core.experiment import Experiment
12-
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
12+
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop
1313
from ax.exceptions.core import UserInputError
1414
from ax.generation_strategy.generation_node import GenerationNode
1515

@@ -41,25 +41,35 @@ def _is_harmful(
4141
experiment=experiment,
4242
)
4343

44-
def _should_stop_trials_early(
44+
def _should_stop_arms(
4545
self,
4646
trial_indices: set[int],
4747
experiment: Experiment,
4848
current_node: GenerationNode | None = None,
49-
) -> dict[int, str | None]:
50-
left = self.left.should_stop_trials_early(
49+
) -> TArmsToStop:
50+
left = self.left.should_stop_arms(
5151
trial_indices=trial_indices,
5252
experiment=experiment,
5353
current_node=current_node,
5454
)
55-
right = self.right.should_stop_trials_early(
55+
right = self.right.should_stop_arms(
5656
trial_indices=trial_indices,
5757
experiment=experiment,
5858
current_node=current_node,
5959
)
60-
return {
61-
trial: f"{left[trial]}, {right[trial]}" for trial in left if trial in right
62-
}
60+
# Combine at the arm level: only stop arms that both strategies agree on
61+
result: TArmsToStop = {}
62+
for trial in left:
63+
if trial in right:
64+
combined_arms: dict[str, str | None] = {}
65+
for arm_name in left[trial]:
66+
if arm_name in right[trial]:
67+
combined_arms[arm_name] = (
68+
f"{left[trial][arm_name]}, {right[trial][arm_name]}"
69+
)
70+
if combined_arms:
71+
result[trial] = combined_arms
72+
return result
6373

6474

6575
class OrEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
@@ -91,21 +101,36 @@ def _is_harmful(
91101
experiment=experiment,
92102
)
93103

94-
def _should_stop_trials_early(
104+
def _should_stop_arms(
95105
self,
96106
trial_indices: set[int],
97107
experiment: Experiment,
98108
current_node: GenerationNode | None = None,
99-
) -> dict[int, str | None]:
100-
return {
101-
**self.left.should_stop_trials_early(
102-
trial_indices=trial_indices,
103-
experiment=experiment,
104-
current_node=current_node,
105-
),
106-
**self.right.should_stop_trials_early(
107-
trial_indices=trial_indices,
108-
experiment=experiment,
109-
current_node=current_node,
110-
),
111-
}
109+
) -> TArmsToStop:
110+
left = self.left.should_stop_arms(
111+
trial_indices=trial_indices,
112+
experiment=experiment,
113+
current_node=current_node,
114+
)
115+
right = self.right.should_stop_arms(
116+
trial_indices=trial_indices,
117+
experiment=experiment,
118+
current_node=current_node,
119+
)
120+
# Merge at arm level: stop arms that either strategy wants to stop
121+
result: TArmsToStop = {}
122+
all_trials = set(left) | set(right)
123+
for trial in all_trials:
124+
left_arms = left.get(trial, {})
125+
right_arms = right.get(trial, {})
126+
merged_arms: dict[str, str | None] = {}
127+
for arm_name in set(left_arms) | set(right_arms):
128+
reasons = []
129+
if arm_name in left_arms and left_arms[arm_name] is not None:
130+
reasons.append(left_arms[arm_name])
131+
if arm_name in right_arms and right_arms[arm_name] is not None:
132+
reasons.append(right_arms[arm_name])
133+
merged_arms[arm_name] = ", ".join(reasons) if reasons else None
134+
if merged_arms:
135+
result[trial] = merged_arms
136+
return result

ax/early_stopping/strategies/percentile.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ax.core.experiment import Experiment
1414
from ax.core.trial_status import TrialStatus
1515
from ax.early_stopping.simulation import best_trial_vulnerable
16-
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
16+
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop
1717
from ax.early_stopping.utils import _is_worse
1818
from ax.exceptions.core import UnsupportedError, UserInputError
1919
from ax.generation_strategy.generation_node import GenerationNode
@@ -175,12 +175,12 @@ def _is_harmful(
175175

176176
return simulated_result.best_stopped
177177

178-
def _should_stop_trials_early(
178+
def _should_stop_arms(
179179
self,
180180
trial_indices: set[int],
181181
experiment: Experiment,
182182
current_node: GenerationNode | None = None,
183-
) -> dict[int, str | None]:
183+
) -> TArmsToStop:
184184
"""Stop a trial if its performance is in the bottom `percentile_threshold`
185185
of the trials at the same step.
186186
@@ -193,9 +193,9 @@ def _should_stop_trials_early(
193193
stopping decisions.
194194
195195
Returns:
196-
A dictionary mapping trial indices that should be early stopped to
197-
(optional) messages with the associated reason. An empty dictionary
198-
means no suggested updates to any trial's status.
196+
A dictionary mapping trial indices to arm-level stopping decisions.
197+
Each value is a dict mapping arm names to (optional) reason strings.
198+
An empty dictionary means no suggested updates.
199199
"""
200200
metric_signature, minimize = self._default_objective_and_direction(
201201
experiment=experiment
@@ -216,21 +216,19 @@ def _should_stop_trials_early(
216216
):
217217
return {}
218218

219-
decisions = {
220-
trial_index: self._should_stop_trial_early(
219+
result: TArmsToStop = {}
220+
for trial_index in trial_indices:
221+
should_stop, reason = self._should_stop_trial_early(
221222
trial_index=trial_index,
222223
experiment=experiment,
223224
wide_df=wide_df,
224225
long_df=long_df,
225226
minimize=minimize,
226227
)
227-
for trial_index in trial_indices
228-
}
229-
return {
230-
trial_index: reason
231-
for trial_index, (should_stop, reason) in decisions.items()
232-
if should_stop
233-
}
228+
if should_stop:
229+
trial = experiment.trials[trial_index]
230+
result[trial_index] = {a.name: reason for a in trial.arms}
231+
return result
234232

235233
def _should_stop_trial_early(
236234
self,
@@ -287,7 +285,7 @@ def _should_stop_trial_early(
287285
window_num_active_trials: pd.Series = window_active_trials.sum(axis=1)
288286

289287
# Verify that sufficiently many trials have data at each progression in
290-
# the patience window. Note: `is_eligible_any` in `should_stop_trials_early`
288+
# the patience window. Note: `is_eligible_any` in `should_stop_arms`
291289
# already checks that at least `min_curves` trials have completed and uses
292290
# `align_partial_results` to interpolate missing values. This condition
293291
# should only trigger if `align_partial_results` fails or if this method

0 commit comments

Comments
 (0)