Skip to content

Commit 546a28b

Browse files
ltiao authored and facebook-github-bot committed
Rewrite MapDataReplayMetric and MapDataReplayRunner to use MapDataReplayState (#5138)
Summary: The experiment replay system (`MapDataReplayMetric`, `MapDataReplayRunner`, `replay_experiment`) is hardcoded for single-objective optimization, blocking multi-objective early stopping. `MapDataReplayMetric` conflates data serving with progression state, so multiple metrics cannot share a coherent timeline. This diff series extracts shared state into a `MapDataReplayState` coordinator. This diff (2/3) rewrites `MapDataReplayMetric` and `MapDataReplayRunner` to delegate to the shared `MapDataReplayState` introduced in D98741817. Also updates `replay_experiment` and `estimate_hypothetical_early_stopping_savings` for multi-metric support: accepts `metrics: list[Metric]`, builds single-objective or multi-objective `OptimizationConfig`, extracts objective thresholds from the historical config, re-indexes non-contiguous trial indices, and deprecates `num_samples_per_curve`. Backward-compat `metric=` kwarg preserved for downstream callers (removed in D98741814). Differential Revision: D98741816
1 parent a2eac73 commit 546a28b

6 files changed

Lines changed: 579 additions & 246 deletions

File tree

ax/early_stopping/experiment_replay.py

Lines changed: 116 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,20 @@
77
# pyre-strict
88

99
import logging
10+
import warnings
1011
from logging import Logger
1112
from time import perf_counter
1213

1314
from ax.adapter.registry import Generators
15+
from ax.core.data import Data
1416
from ax.core.experiment import Experiment
1517
from ax.core.metric import Metric
16-
from ax.core.objective import Objective
17-
from ax.core.optimization_config import OptimizationConfig
18+
from ax.core.objective import MultiObjective, Objective
19+
from ax.core.optimization_config import (
20+
MultiObjectiveOptimizationConfig,
21+
OptimizationConfig,
22+
)
23+
from ax.core.outcome_constraint import OutcomeConstraint
1824
from ax.core.parameter import ParameterType, RangeParameter
1925
from ax.core.search_space import SearchSpace
2026
from ax.early_stopping.dispatch import get_default_ess_or_none
@@ -25,7 +31,7 @@
2531
GenerationStep,
2632
GenerationStrategy,
2733
)
28-
from ax.metrics.map_replay import MapDataReplayMetric
34+
from ax.metrics.map_replay import MapDataReplayMetric, MapDataReplayState
2935
from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
3036
from ax.runners.map_replay import MapDataReplayRunner
3137
from ax.utils.common.logger import get_logger
@@ -43,16 +49,51 @@ def replay_experiment(
4349
historical_experiment: Experiment,
4450
num_samples_per_curve: int,
4551
max_replay_trials: int,
46-
metric: Metric,
47-
max_pending_trials: int,
48-
early_stopping_strategy: BaseEarlyStoppingStrategy | None,
52+
metrics: list[Metric] | None = None,
53+
max_pending_trials: int = MAX_PENDING_TRIALS,
54+
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
4955
logging_level: int = logging.ERROR,
56+
# Deprecated backward-compat kwarg
57+
metric: Metric | None = None,
5058
) -> Experiment | None:
51-
"""A utility function for replaying a historical experiment's data
52-
by initializing a Orchestrator that quickly steps through the existing data.
53-
The main purpose of this function is to compute an hypothetical capacity
54-
savings for a given `early_stopping_strategy`.
59+
"""Replay a historical experiment's data through an Orchestrator.
60+
61+
Initializes an Orchestrator that steps through existing data to compute
62+
hypothetical capacity savings for a given ``early_stopping_strategy``.
63+
Supports both single-objective and multi-objective optimization.
64+
65+
Args:
66+
historical_experiment: The experiment whose data to replay.
67+
num_samples_per_curve: Deprecated. Number of samples per curve for
68+
subsampling. Use ``step_size`` on ``MapDataReplayState`` instead.
69+
max_replay_trials: Maximum number of trials to replay.
70+
metrics: List of metrics to replay. For single-objective, provide
71+
one metric. For multi-objective, provide multiple metrics.
72+
max_pending_trials: Maximum number of pending trials for the
73+
replay orchestrator.
74+
early_stopping_strategy: The early stopping strategy to evaluate.
75+
logging_level: Logging level for the orchestrator.
76+
metric: Deprecated. Use ``metrics`` instead.
5577
"""
78+
# Backward compat: accept metric= (singular) and wrap in list
79+
if metric is not None:
80+
warnings.warn(
81+
"The `metric` parameter is deprecated. Use `metrics` instead.",
82+
DeprecationWarning,
83+
stacklevel=2,
84+
)
85+
if metrics is not None:
86+
raise ValueError("Cannot specify both `metric` and `metrics`.")
87+
metrics = [metric]
88+
if metrics is None:
89+
raise ValueError("Must specify `metrics`.")
90+
warnings.warn(
91+
"The `num_samples_per_curve` parameter is deprecated and will be "
92+
"removed in a future release. The `step_size` parameter on "
93+
"`MapDataReplayState` controls replay granularity.",
94+
DeprecationWarning,
95+
stacklevel=2,
96+
)
5697
historical_map_data = historical_experiment.lookup_data()
5798
if not historical_map_data.has_step_column:
5899
logger.warning(
@@ -62,16 +103,51 @@ def replay_experiment(
62103
historical_map_data = historical_map_data.subsample(
63104
limit_rows_per_group=num_samples_per_curve, include_first_last=True
64105
)
65-
replay_metric = MapDataReplayMetric(
66-
name=f"replay_{historical_experiment.name}",
67-
map_data=historical_map_data,
68-
metric_name=metric.name,
69-
lower_is_better=metric.lower_is_better,
70-
)
71-
optimization_config = OptimizationConfig(
72-
objective=Objective(metric=replay_metric),
106+
107+
# Re-index non-contiguous trial indices to contiguous 0, 1, 2, ...
108+
# so that replay trial N maps to the Nth historical trial.
109+
df = historical_map_data.full_df
110+
sorted_trial_indices = sorted(df["trial_index"].unique())
111+
trial_index_map = {old: new for new, old in enumerate(sorted_trial_indices)}
112+
df = df.copy()
113+
df["trial_index"] = df["trial_index"].map(trial_index_map)
114+
historical_map_data = Data(df=df)
115+
116+
metric_signatures = [m.signature for m in metrics]
117+
replay_state = MapDataReplayState(
118+
map_data=historical_map_data, metric_signatures=metric_signatures
73119
)
74-
runner = MapDataReplayRunner(replay_metric=replay_metric)
120+
121+
replay_metrics = [
122+
MapDataReplayMetric(
123+
name=m.name,
124+
replay_state=replay_state,
125+
metric_signature=m.signature,
126+
lower_is_better=m.lower_is_better,
127+
)
128+
for m in metrics
129+
]
130+
131+
if len(replay_metrics) == 1:
132+
optimization_config: OptimizationConfig = OptimizationConfig(
133+
objective=Objective(metric=replay_metrics[0]),
134+
)
135+
else:
136+
# Extract objective thresholds from the historical experiment's config
137+
historical_opt_config = historical_experiment.optimization_config
138+
objective_thresholds: list[OutcomeConstraint] = []
139+
if isinstance(historical_opt_config, MultiObjectiveOptimizationConfig):
140+
objective_thresholds = [
141+
ot.clone() for ot in historical_opt_config.objective_thresholds
142+
]
143+
optimization_config = MultiObjectiveOptimizationConfig(
144+
objective=MultiObjective(
145+
objectives=[Objective(metric=m) for m in replay_metrics]
146+
),
147+
objective_thresholds=objective_thresholds,
148+
)
149+
150+
runner = MapDataReplayRunner(replay_state=replay_state)
75151

76152
# Setup a new experiment with a dummy search space
77153
dummy_search_space = SearchSpace(
@@ -89,10 +165,10 @@ def replay_experiment(
89165
optimization_config=optimization_config,
90166
search_space=dummy_search_space,
91167
runner=runner,
92-
metrics=[replay_metric],
168+
metrics=replay_metrics,
93169
)
94170

95-
# Setup a Orchestrator with a dummy gs to replay the historical experiment
171+
# Setup an Orchestrator with a dummy gs to replay the historical experiment
96172
dummy_sobol_gs = GenerationStrategy(
97173
name="sobol",
98174
steps=[
@@ -101,7 +177,7 @@ def replay_experiment(
101177
)
102178
options = OrchestratorOptions(
103179
max_pending_trials=max_pending_trials,
104-
total_trials=min(len(historical_experiment.trials), max_replay_trials),
180+
total_trials=min(len(sorted_trial_indices), max_replay_trials),
105181
seconds_between_polls_backoff_factor=1.0,
106182
min_seconds_before_poll=0.0,
107183
init_seconds_between_polls=0,
@@ -119,8 +195,10 @@ def replay_experiment(
119195

120196
def estimate_hypothetical_early_stopping_savings(
121197
experiment: Experiment,
122-
metric: Metric,
198+
metrics: list[Metric] | None = None,
123199
max_pending_trials: int = MAX_PENDING_TRIALS,
200+
# Deprecated backward-compat kwarg
201+
metric: Metric | None = None,
124202
) -> float:
125203
"""Estimate hypothetical early stopping savings using experiment replay.
126204
@@ -130,9 +208,10 @@ def estimate_hypothetical_early_stopping_savings(
130208
131209
Args:
132210
experiment: The experiment to analyze.
133-
metric: The metric to use for early stopping replay.
211+
metrics: The metrics to use for early stopping replay.
134212
max_pending_trials: Maximum number of pending trials for the replay
135213
orchestrator. Defaults to 5.
214+
metric: Deprecated. Use ``metrics`` instead.
136215
137216
Returns:
138217
Estimated savings as a fraction (0.0 to 1.0).
@@ -145,6 +224,18 @@ def estimate_hypothetical_early_stopping_savings(
145224
- The experiment data does not have progression data for replay
146225
- The experiment replay fails due to invalid experiment state
147226
"""
227+
# Backward compat: accept metric= (singular) and wrap in list
228+
if metric is not None:
229+
warnings.warn(
230+
"The `metric` parameter is deprecated. Use `metrics` instead.",
231+
DeprecationWarning,
232+
stacklevel=2,
233+
)
234+
if metrics is not None:
235+
raise ValueError("Cannot specify both `metric` and `metrics`.")
236+
metrics = [metric]
237+
if metrics is None:
238+
raise ValueError("Must specify `metrics`.")
148239
default_ess = get_default_ess_or_none(experiment=experiment)
149240
if default_ess is None:
150241
raise UnsupportedError(
@@ -156,7 +247,7 @@ def estimate_hypothetical_early_stopping_savings(
156247
historical_experiment=experiment,
157248
num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
158249
max_replay_trials=MAX_REPLAY_TRIALS,
159-
metric=metric,
250+
metrics=metrics,
160251
max_pending_trials=max_pending_trials,
161252
early_stopping_strategy=default_ess,
162253
)

0 commit comments

Comments (0)