Skip to content

Commit 53f455f

Browse files
committed
fix: don't only evaluate final input in GT results
1 parent 9bedf9e commit 53f455f

2 files changed

Lines changed: 86 additions & 5 deletions

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,17 +224,51 @@ def _initialize_class_members_from_config(
224224
self._baseline_cost_usd: Optional[float] = None
225225

226226
def _record_baseline(self, ctx: OptimizationContext) -> None:
227-
"""Capture duration/cost baseline from the first iteration appended to history.
227+
"""Capture duration/cost baseline from a single context.
228228
229-
Called once per run (subsequent calls are no-ops once both values are set).
230-
Storing these explicitly lets ``_trim_history`` use a simple tail slice without
231-
needing to preserve ``history[0]`` as an anchor.
229+
Used by the standard (non-GT) optimization loop where each iteration
230+
produces one result. Called once per run (subsequent calls are no-ops
231+
once both values are set). Storing these explicitly lets
232+
``_trim_history`` use a simple tail slice without needing to preserve
233+
``history[0]`` as an anchor.
232234
"""
233235
if self._baseline_duration_ms is None and ctx.duration_ms is not None:
234236
self._baseline_duration_ms = ctx.duration_ms
235237
if self._baseline_cost_usd is None and ctx.estimated_cost_usd is not None:
236238
self._baseline_cost_usd = ctx.estimated_cost_usd
237239

240+
def _record_baseline_from_batch(self, attempt_results: List[OptimizationContext]) -> None:
241+
"""Capture duration/cost baseline as the average across a GT batch.
242+
243+
Used by the GT optimization loop. The first attempt's N samples form
244+
the baseline; averaging them gives a more stable reference than a
245+
single sample and ensures comparisons in subsequent attempts reflect
246+
the typical performance of the original configuration rather than an
247+
outlier measurement.
248+
249+
Called once per run (subsequent calls are no-ops once both values are
250+
set).
251+
252+
:param attempt_results: All completed sample contexts from the first
253+
GT attempt.
254+
"""
255+
if not attempt_results:
256+
return
257+
if self._baseline_duration_ms is None:
258+
durations = [
259+
ctx.duration_ms for ctx in attempt_results if ctx.duration_ms is not None
260+
]
261+
if durations:
262+
self._baseline_duration_ms = sum(durations) / len(durations)
263+
if self._baseline_cost_usd is None:
264+
costs = [
265+
ctx.estimated_cost_usd
266+
for ctx in attempt_results
267+
if ctx.estimated_cost_usd is not None
268+
]
269+
if costs:
270+
self._baseline_cost_usd = sum(costs) / len(costs)
271+
238272
def _build_agent_config_for_context(
239273
self, ctx: OptimizationContext, skip_interpolation: bool = False
240274
) -> AIAgentConfig:
@@ -1272,7 +1306,7 @@ async def _run_ground_truth_optimization(
12721306
# from all of the previous samples, then trim to one full attempt's worth so
12731307
# judge prompts don't grow unboundedly across many failed attempts.
12741308
if attempt_results:
1275-
self._record_baseline(attempt_results[0])
1309+
self._record_baseline_from_batch(attempt_results)
12761310
self._history.extend(attempt_results)
12771311
self._history = _trim_history(self._history, n)
12781312

packages/optimization/tests/test_client.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for OptimizationClient."""
22

3+
import dataclasses
34
import json
45
from typing import Any, Dict
56
from unittest.mock import AsyncMock, MagicMock, patch
@@ -5176,6 +5177,52 @@ def test_skips_gracefully_when_units_differ_across_model_switch(self):
51765177
assert self.client._evaluate_cost(self._ctx(None)) is True
51775178

51785179

5180+
# ---------------------------------------------------------------------------
5181+
# _record_baseline_from_batch
5182+
# ---------------------------------------------------------------------------
5183+
5184+
5185+
class TestRecordBaselineFromBatch:
5186+
def setup_method(self):
5187+
self.client = _make_client()
5188+
self.client._initialize_class_members_from_config(_make_agent_config())
5189+
5190+
def _ctx(self, duration_ms=None, cost=None):
5191+
ctx = self.client._create_optimization_context(iteration=1, variables={})
5192+
return dataclasses.replace(ctx, duration_ms=duration_ms, estimated_cost_usd=cost)
5193+
5194+
def test_averages_duration_across_batch(self):
5195+
results = [self._ctx(duration_ms=1000), self._ctx(duration_ms=2000), self._ctx(duration_ms=3000)]
5196+
self.client._record_baseline_from_batch(results)
5197+
assert self.client._baseline_duration_ms == 2000.0
5198+
5199+
def test_averages_cost_across_batch(self):
5200+
results = [self._ctx(cost=0.01), self._ctx(cost=0.02), self._ctx(cost=0.03)]
5201+
self.client._record_baseline_from_batch(results)
5202+
assert abs(self.client._baseline_cost_usd - 0.02) < 1e-9
5203+
5204+
def test_skips_none_values_in_average(self):
5205+
results = [self._ctx(duration_ms=1000), self._ctx(duration_ms=None), self._ctx(duration_ms=3000)]
5206+
self.client._record_baseline_from_batch(results)
5207+
assert self.client._baseline_duration_ms == 2000.0
5208+
5209+
def test_noop_when_already_set(self):
5210+
self.client._baseline_duration_ms = 999.0
5211+
results = [self._ctx(duration_ms=1000), self._ctx(duration_ms=2000)]
5212+
self.client._record_baseline_from_batch(results)
5213+
assert self.client._baseline_duration_ms == 999.0
5214+
5215+
def test_noop_on_empty_list(self):
5216+
self.client._record_baseline_from_batch([])
5217+
assert self.client._baseline_duration_ms is None
5218+
assert self.client._baseline_cost_usd is None
5219+
5220+
def test_noop_when_all_values_none(self):
5221+
results = [self._ctx(duration_ms=None), self._ctx(duration_ms=None)]
5222+
self.client._record_baseline_from_batch(results)
5223+
assert self.client._baseline_duration_ms is None
5224+
5225+
51795226
# ---------------------------------------------------------------------------
51805227
# _apply_duration_gate
51815228
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)