|
23 | 23 | logger: Logger = get_logger(__name__) |
24 | 24 |
|
25 | 25 |
|
class MapDataReplayState:
    """Shared state coordinator for replaying historical map data.

    Manages normalized cursor-based progression across multiple metrics
    and trials. Every trial owns a cursor that starts at 0 and is capped at
    1; the cursor is mapped linearly onto the global ``[min_prog, max_prog]``
    MAP_KEY range computed across all replayed metrics and trials, which
    preserves cross-metric timing alignment.

    This class serves original MAP_KEY values (not normalized). Downstream
    early stopping strategies apply normalization independently via
    ``_maybe_normalize_map_key`` in ``ax.adapter.data_utils``.
    """

    def __init__(
        self,
        map_data: Data,
        metric_signatures: list[str],
        step_size: float = 0.01,
    ) -> None:
        """Initialize replay state from historical data.

        Args:
            map_data: Historical data containing progression data. Its
                ``full_df`` is read once and must provide the
                ``metric_signature``, ``trial_index`` and MAP_KEY columns.
            metric_signatures: List of metric signatures to replay.
            step_size: Cursor increment per advancement step. Determines
                the granularity of replay (e.g. 0.01 = 100 steps).
        """
        self.step_size: float = step_size

        # Pre-index data by (trial_index, metric_signature) for O(1) lookups.
        self._data: dict[tuple[int, str], pd.DataFrame] = {}
        per_trial_max_prog: dict[int, float] = {}
        # Running extrema of MAP_KEY over everything indexed so far; they stay
        # None when no row matches any requested metric.
        global_min: float | None = None
        global_max: float | None = None

        # Loop-invariant: fetch the full frame once rather than per metric.
        full_df = map_data.full_df
        for metric_signature in metric_signatures:
            metric_df = full_df[full_df["metric_signature"] == metric_signature]
            replay_df = metric_df.sort_values(
                by=["trial_index", MAP_KEY], ascending=True, ignore_index=True
            )
            for raw_trial_index, group in replay_df.groupby("trial_index"):
                trial_index = int(raw_trial_index)
                self._data[(trial_index, metric_signature)] = group.reset_index(
                    drop=True
                )
                # groupby never yields an empty group, so min/max are defined.
                prog_values = group[MAP_KEY].values
                group_min = float(prog_values.min())
                group_max = float(prog_values.max())
                global_min = (
                    group_min if global_min is None else min(global_min, group_min)
                )
                global_max = (
                    group_max if global_max is None else max(global_max, group_max)
                )
                # A trial's end point is the furthest progression reached by
                # any of its replayed metrics.
                per_trial_max_prog[trial_index] = max(
                    per_trial_max_prog.get(trial_index, group_max), group_max
                )

        # With no data at all the range collapses to a single point (0.0).
        self.min_prog: float = 0.0 if global_min is None else global_min
        self.max_prog: float = 0.0 if global_max is None else global_max

        self._per_trial_max_prog: dict[int, float] = per_trial_max_prog
        self._trial_cursors: defaultdict[int, float] = defaultdict(float)
        # Every indexed trial has an entry in per_trial_max_prog.
        self._trial_indices: set[int] = set(per_trial_max_prog)

    def _current_progression(self, trial_index: int) -> float:
        """Map a trial's cursor onto the global MAP_KEY range.

        A cursor at the end of its range returns ``max_prog`` exactly:
        ``min_prog + 1.0 * (max_prog - min_prog)`` may round to a value just
        below ``max_prog``, which would keep the final rows from ever being
        served and the trial from ever completing.
        """
        # ``.get`` keeps read-only queries from inserting cursor entries.
        cursor = self._trial_cursors.get(trial_index, 0.0)
        if cursor >= 1.0:
            return self.max_prog
        return self.min_prog + cursor * (self.max_prog - self.min_prog)

    def advance_trial(self, trial_index: int) -> None:
        """Advance the cursor for a trial by one step, capping it at 1.0."""
        self._trial_cursors[trial_index] = min(
            self._trial_cursors[trial_index] + self.step_size, 1.0
        )

    def has_trial_data(self, trial_index: int) -> bool:
        """Check if any replay data exists for a given trial."""
        return trial_index in self._trial_indices

    def is_trial_complete(self, trial_index: int) -> bool:
        """Check if a trial's cursor has reached its maximum progression.

        Returns ``True`` immediately when the global range is a single point
        or when the trial has no replay data, since there is nothing left to
        replay in either case.
        """
        if self.min_prog == self.max_prog:
            return True
        trial_max = self._per_trial_max_prog.get(trial_index)
        if trial_max is None:
            # No data for this trial: do not compare against an arbitrary
            # default whose outcome would depend on the sign of min_prog.
            return True
        return self._current_progression(trial_index) >= trial_max

    def get_data(self, trial_index: int, metric_signature: str) -> pd.DataFrame:
        """Get replay data for a trial up to the current cursor position.

        Returns a DataFrame filtered to rows where MAP_KEY <= current
        progression value, with original (non-normalized) MAP_KEY values.
        An empty DataFrame is returned when nothing is indexed for the
        ``(trial_index, metric_signature)`` pair, and all of the pair's rows
        are returned when the global range is a single point.
        """
        df = self._data.get((trial_index, metric_signature))
        if df is None:
            return pd.DataFrame()
        if self.min_prog == self.max_prog:
            return df
        return df[df[MAP_KEY] <= self._current_progression(trial_index)]
| 127 | + |
| 128 | + |
26 | 129 | class MapDataReplayMetric(MapMetric): |
27 | 130 | """A metric for replaying historical map data.""" |
28 | 131 |
|
|
0 commit comments