Skip to content

Commit a2eac73

Browse files
ltiaofacebook-github-bot
authored andcommitted
Add MapDataReplayState coordinator class (#5137)
Summary: The experiment replay system (`MapDataReplayMetric`, `MapDataReplayRunner`, `replay_experiment`) is hardcoded for single-objective optimization, blocking multi-objective early stopping. `MapDataReplayMetric` conflates data serving with progression state, so multiple metrics cannot share a coherent timeline. This diff series extracts shared state into a `MapDataReplayState` coordinator. This diff (1/3) adds the `MapDataReplayState` class -- a shared state coordinator that manages normalized cursor-based progression across multiple metrics and trials. Uses a global min/max MAP_KEY to preserve cross-metric timing alignment. Serves original MAP_KEY values; downstream ESS normalizes independently via `_maybe_normalize_map_key`. Existing `MapDataReplayMetric` and helpers are unchanged in this diff. Differential Revision: D98741817
1 parent 227243c commit a2eac73

2 files changed

Lines changed: 339 additions & 20 deletions

File tree

ax/metrics/map_replay.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,109 @@
2323
logger: Logger = get_logger(__name__)
2424

2525

26+
class MapDataReplayState:
27+
"""Shared state coordinator for replaying historical map data.
28+
29+
Manages normalized cursor-based progression across multiple metrics
30+
and trials. The cursor model uses a global min/max MAP_KEY across
31+
all metrics to preserve cross-metric timing alignment.
32+
33+
This class serves original MAP_KEY values (not normalized). Downstream
34+
early stopping strategies apply normalization independently via
35+
``_maybe_normalize_map_key`` in ``ax.adapter.data_utils``.
36+
"""
37+
38+
def __init__(
39+
self,
40+
map_data: Data,
41+
metric_signatures: list[str],
42+
step_size: float = 0.01,
43+
) -> None:
44+
"""Initialize replay state from historical data.
45+
46+
Args:
47+
map_data: Historical data containing progression data.
48+
metric_signatures: List of metric signatures to replay.
49+
step_size: Cursor increment per advancement step. Determines
50+
the granularity of replay (e.g. 0.01 = 100 steps).
51+
"""
52+
self.step_size: float = step_size
53+
54+
# Pre-index data by (trial_index, metric_signature) for O(1) lookups
55+
self._data: dict[tuple[int, str], pd.DataFrame] = {}
56+
all_trial_indices: set[int] = set()
57+
all_prog_values: list[float] = []
58+
per_trial_max_prog: dict[int, float] = {}
59+
60+
for metric_signature in metric_signatures:
61+
df = map_data.full_df
62+
df = df[df["metric_signature"] == metric_signature]
63+
replay_df = df.sort_values(
64+
by=["trial_index", MAP_KEY], ascending=True, ignore_index=True
65+
)
66+
for trial_index, group in replay_df.groupby("trial_index"):
67+
trial_index = int(trial_index)
68+
self._data[(trial_index, metric_signature)] = group.reset_index(
69+
drop=True
70+
)
71+
all_trial_indices.add(trial_index)
72+
prog_values = group[MAP_KEY].values
73+
all_prog_values.extend(prog_values.tolist())
74+
trial_max = float(prog_values.max())
75+
if trial_index in per_trial_max_prog:
76+
per_trial_max_prog[trial_index] = max(
77+
per_trial_max_prog[trial_index], trial_max
78+
)
79+
else:
80+
per_trial_max_prog[trial_index] = trial_max
81+
82+
if all_prog_values:
83+
self.min_prog: float = float(min(all_prog_values))
84+
self.max_prog: float = float(max(all_prog_values))
85+
else:
86+
self.min_prog = 0.0
87+
self.max_prog = 0.0
88+
89+
self._per_trial_max_prog: dict[int, float] = per_trial_max_prog
90+
self._trial_cursors: defaultdict[int, float] = defaultdict(float)
91+
self._trial_indices: set[int] = all_trial_indices
92+
93+
def advance_trial(self, trial_index: int) -> None:
94+
"""Advance the cursor for a trial by one step."""
95+
self._trial_cursors[trial_index] = min(
96+
self._trial_cursors[trial_index] + self.step_size, 1.0
97+
)
98+
99+
def has_trial_data(self, trial_index: int) -> bool:
100+
"""Check if any replay data exists for a given trial."""
101+
return trial_index in self._trial_indices
102+
103+
def is_trial_complete(self, trial_index: int) -> bool:
104+
"""Check if a trial's cursor has reached its maximum progression."""
105+
if self.min_prog == self.max_prog:
106+
return True
107+
curr_prog = self.min_prog + self._trial_cursors[trial_index] * (
108+
self.max_prog - self.min_prog
109+
)
110+
return curr_prog >= self._per_trial_max_prog.get(trial_index, 0.0)
111+
112+
def get_data(self, trial_index: int, metric_signature: str) -> pd.DataFrame:
113+
"""Get replay data for a trial up to the current cursor position.
114+
115+
Returns a DataFrame filtered to rows where MAP_KEY <= current
116+
progression value, with original (non-normalized) MAP_KEY values.
117+
"""
118+
df = self._data.get((trial_index, metric_signature))
119+
if df is None:
120+
return pd.DataFrame()
121+
if self.min_prog == self.max_prog:
122+
return df
123+
curr_prog = self.min_prog + self._trial_cursors[trial_index] * (
124+
self.max_prog - self.min_prog
125+
)
126+
return df[df[MAP_KEY] <= curr_prog]
127+
128+
26129
class MapDataReplayMetric(MapMetric):
27130
"""A metric for replaying historical map data."""
28131

0 commit comments

Comments
 (0)