77# pyre-strict
88
99import logging
10+ import warnings
1011from logging import Logger
1112from time import perf_counter
1213
1314from ax .adapter .registry import Generators
15+ from ax .core .data import Data
1416from ax .core .experiment import Experiment
1517from ax .core .metric import Metric
16- from ax .core .objective import Objective
17- from ax .core .optimization_config import OptimizationConfig
18+ from ax .core .objective import MultiObjective , Objective
19+ from ax .core .optimization_config import (
20+ MultiObjectiveOptimizationConfig ,
21+ OptimizationConfig ,
22+ )
23+ from ax .core .outcome_constraint import OutcomeConstraint
1824from ax .core .parameter import ParameterType , RangeParameter
1925from ax .core .search_space import SearchSpace
2026from ax .early_stopping .dispatch import get_default_ess_or_none
2531 GenerationStep ,
2632 GenerationStrategy ,
2733)
28- from ax .metrics .map_replay import MapDataReplayMetric
34+ from ax .metrics .map_replay import MapDataReplayMetric , MapDataReplayState
2935from ax .orchestration .orchestrator import Orchestrator , OrchestratorOptions
3036from ax .runners .map_replay import MapDataReplayRunner
3137from ax .utils .common .logger import get_logger
@@ -43,16 +49,51 @@ def replay_experiment(
4349 historical_experiment : Experiment ,
4450 num_samples_per_curve : int ,
4551 max_replay_trials : int ,
46- metric : Metric ,
47- max_pending_trials : int ,
48- early_stopping_strategy : BaseEarlyStoppingStrategy | None ,
52+ metrics : list [ Metric ] | None = None ,
53+ max_pending_trials : int = MAX_PENDING_TRIALS ,
54+ early_stopping_strategy : BaseEarlyStoppingStrategy | None = None ,
4955 logging_level : int = logging .ERROR ,
56+ # Deprecated backward-compat kwarg
57+ metric : Metric | None = None ,
5058) -> Experiment | None :
51- """A utility function for replaying a historical experiment's data
52- by initializing a Orchestrator that quickly steps through the existing data.
53- The main purpose of this function is to compute an hypothetical capacity
54- savings for a given `early_stopping_strategy`.
59+ """Replay a historical experiment's data through an Orchestrator.
60+
61+ Initializes an Orchestrator that steps through existing data to compute
62+ hypothetical capacity savings for a given ``early_stopping_strategy``.
63+ Supports both single-objective and multi-objective optimization.
64+
65+ Args:
66+ historical_experiment: The experiment whose data to replay.
67+ num_samples_per_curve: Deprecated. Number of samples per curve for
68+ subsampling. Use ``step_size`` on ``MapDataReplayState`` instead.
69+ max_replay_trials: Maximum number of trials to replay.
70+ metrics: List of metrics to replay. For single-objective, provide
71+ one metric. For multi-objective, provide multiple metrics.
72+ max_pending_trials: Maximum number of pending trials for the
73+ replay orchestrator.
74+ early_stopping_strategy: The early stopping strategy to evaluate.
75+ logging_level: Logging level for the orchestrator.
76+ metric: Deprecated. Use ``metrics`` instead.
5577 """
78+ # Backward compat: accept metric= (singular) and wrap in list
79+ if metric is not None :
80+ warnings .warn (
81+ "The `metric` parameter is deprecated. Use `metrics` instead." ,
82+ DeprecationWarning ,
83+ stacklevel = 2 ,
84+ )
85+ if metrics is not None :
86+ raise ValueError ("Cannot specify both `metric` and `metrics`." )
87+ metrics = [metric ]
88+ if metrics is None :
89+ raise ValueError ("Must specify `metrics`." )
90+ warnings .warn (
91+ "The `num_samples_per_curve` parameter is deprecated and will be "
92+ "removed in a future release. The `step_size` parameter on "
93+ "`MapDataReplayState` controls replay granularity." ,
94+ DeprecationWarning ,
95+ stacklevel = 2 ,
96+ )
5697 historical_map_data = historical_experiment .lookup_data ()
5798 if not historical_map_data .has_step_column :
5899 logger .warning (
@@ -62,16 +103,51 @@ def replay_experiment(
62103 historical_map_data = historical_map_data .subsample (
63104 limit_rows_per_group = num_samples_per_curve , include_first_last = True
64105 )
65- replay_metric = MapDataReplayMetric (
66- name = f"replay_{ historical_experiment .name } " ,
67- map_data = historical_map_data ,
68- metric_name = metric .name ,
69- lower_is_better = metric .lower_is_better ,
70- )
71- optimization_config = OptimizationConfig (
72- objective = Objective (metric = replay_metric ),
106+
107+ # Re-index non-contiguous trial indices to contiguous 0, 1, 2, ...
108+ # so that replay trial N maps to the Nth historical trial.
109+ df = historical_map_data .full_df
110+ sorted_trial_indices = sorted (df ["trial_index" ].unique ())
111+ trial_index_map = {old : new for new , old in enumerate (sorted_trial_indices )}
112+ df = df .copy ()
113+ df ["trial_index" ] = df ["trial_index" ].map (trial_index_map )
114+ historical_map_data = Data (df = df )
115+
116+ metric_signatures = [m .signature for m in metrics ]
117+ replay_state = MapDataReplayState (
118+ map_data = historical_map_data , metric_signatures = metric_signatures
73119 )
74- runner = MapDataReplayRunner (replay_metric = replay_metric )
120+
121+ replay_metrics = [
122+ MapDataReplayMetric (
123+ name = m .name ,
124+ replay_state = replay_state ,
125+ metric_signature = m .signature ,
126+ lower_is_better = m .lower_is_better ,
127+ )
128+ for m in metrics
129+ ]
130+
131+ if len (replay_metrics ) == 1 :
132+ optimization_config : OptimizationConfig = OptimizationConfig (
133+ objective = Objective (metric = replay_metrics [0 ]),
134+ )
135+ else :
136+ # Extract objective thresholds from the historical experiment's config
137+ historical_opt_config = historical_experiment .optimization_config
138+ objective_thresholds : list [OutcomeConstraint ] = []
139+ if isinstance (historical_opt_config , MultiObjectiveOptimizationConfig ):
140+ objective_thresholds = [
141+ ot .clone () for ot in historical_opt_config .objective_thresholds
142+ ]
143+ optimization_config = MultiObjectiveOptimizationConfig (
144+ objective = MultiObjective (
145+ objectives = [Objective (metric = m ) for m in replay_metrics ]
146+ ),
147+ objective_thresholds = objective_thresholds ,
148+ )
149+
150+ runner = MapDataReplayRunner (replay_state = replay_state )
75151
76152 # Setup a new experiment with a dummy search space
77153 dummy_search_space = SearchSpace (
@@ -89,10 +165,10 @@ def replay_experiment(
89165 optimization_config = optimization_config ,
90166 search_space = dummy_search_space ,
91167 runner = runner ,
92- metrics = [ replay_metric ] ,
168+ metrics = replay_metrics ,
93169 )
94170
95- # Setup a Orchestrator with a dummy gs to replay the historical experiment
171+ # Setup an Orchestrator with a dummy gs to replay the historical experiment
96172 dummy_sobol_gs = GenerationStrategy (
97173 name = "sobol" ,
98174 steps = [
@@ -101,7 +177,7 @@ def replay_experiment(
101177 )
102178 options = OrchestratorOptions (
103179 max_pending_trials = max_pending_trials ,
104- total_trials = min (len (historical_experiment . trials ), max_replay_trials ),
180+ total_trials = min (len (sorted_trial_indices ), max_replay_trials ),
105181 seconds_between_polls_backoff_factor = 1.0 ,
106182 min_seconds_before_poll = 0.0 ,
107183 init_seconds_between_polls = 0 ,
@@ -119,8 +195,10 @@ def replay_experiment(
119195
120196def estimate_hypothetical_early_stopping_savings (
121197 experiment : Experiment ,
122- metric : Metric ,
198+ metrics : list [ Metric ] | None = None ,
123199 max_pending_trials : int = MAX_PENDING_TRIALS ,
200+ # Deprecated backward-compat kwarg
201+ metric : Metric | None = None ,
124202) -> float :
125203 """Estimate hypothetical early stopping savings using experiment replay.
126204
@@ -130,9 +208,10 @@ def estimate_hypothetical_early_stopping_savings(
130208
131209 Args:
132210 experiment: The experiment to analyze.
133- metric : The metric to use for early stopping replay.
211+ metrics : The metrics to use for early stopping replay.
134212 max_pending_trials: Maximum number of pending trials for the replay
135213 orchestrator. Defaults to 5.
214+ metric: Deprecated. Use ``metrics`` instead.
136215
137216 Returns:
138217 Estimated savings as a fraction (0.0 to 1.0).
@@ -145,6 +224,18 @@ def estimate_hypothetical_early_stopping_savings(
145224 - The experiment data does not have progression data for replay
146225 - The experiment replay fails due to invalid experiment state
147226 """
227+ # Backward compat: accept metric= (singular) and wrap in list
228+ if metric is not None :
229+ warnings .warn (
230+ "The `metric` parameter is deprecated. Use `metrics` instead." ,
231+ DeprecationWarning ,
232+ stacklevel = 2 ,
233+ )
234+ if metrics is not None :
235+ raise ValueError ("Cannot specify both `metric` and `metrics`." )
236+ metrics = [metric ]
237+ if metrics is None :
238+ raise ValueError ("Must specify `metrics`." )
148239 default_ess = get_default_ess_or_none (experiment = experiment )
149240 if default_ess is None :
150241 raise UnsupportedError (
@@ -156,7 +247,7 @@ def estimate_hypothetical_early_stopping_savings(
156247 historical_experiment = experiment ,
157248 num_samples_per_curve = REPLAY_NUM_POINTS_PER_CURVE ,
158249 max_replay_trials = MAX_REPLAY_TRIALS ,
159- metric = metric ,
250+ metrics = metrics ,
160251 max_pending_trials = max_pending_trials ,
161252 early_stopping_strategy = default_ess ,
162253 )
0 commit comments