Skip to content

Commit f2f0894

Browse files
committed
fix: address cursor feedback
1 parent 66bc1f0 commit f2f0894

3 files changed

Lines changed: 122 additions & 36 deletions

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def __init__(self, ldClient: LDAIClient) -> None:
212212
self._initial_tool_keys: List[str] = []
213213
self._total_token_usage: int = 0
214214
self._model_configs: List[Dict[str, Any]] = []
215+
self._last_batch_size: int = 1
215216

216217
if os.environ.get("LAUNCHDARKLY_API_KEY"):
217218
self._has_api_key = True
@@ -1123,6 +1124,7 @@ async def _run_ground_truth_optimization(
11231124
self._last_succeeded_context = None
11241125
self._last_optimization_result_id = None
11251126
self._total_token_usage = 0
1127+
self._last_batch_size = 1
11261128
self._initialize_class_members_from_config(agent_config)
11271129

11281130
# Seed from the first model choice on the first iteration
@@ -1331,6 +1333,9 @@ async def _run_ground_truth_optimization(
13311333
self._record_baseline_from_batch(attempt_results)
13321334
self._history.extend(attempt_results)
13331335
self._history = _trim_history(self._history, n)
1336+
# Track batch size so _all_judges_passing checks every sample in this
1337+
# attempt, not just the last one.
1338+
self._last_batch_size = n
13341339

13351340
logger.info(
13361341
"[GT Attempt %d] -> %d/%d samples failed — generating new variation",
@@ -2140,28 +2145,35 @@ def _evaluate_cost(self, optimize_context: OptimizationContext) -> bool:
21402145
return passed
21412146

21422147
def _all_judges_passing(self) -> bool:
2143-
"""Return True if every user-configured judge passed in the most recent history entry.
2148+
"""Return True if every user-configured judge passed in every sample of the most recent batch.
21442149
2145-
Inspects the last context in ``_history`` and checks each score key that
2146-
corresponds to a judge defined in ``_options.judges`` (skipping synthetic gate
2147-
entries whose keys begin with ``_``). Returns False when history is empty or any
2148-
judge score does not meet its threshold.
2150+
In ground-truth mode the last ``_last_batch_size`` entries in ``_history``
2151+
correspond to the samples from the latest attempt. All of them must pass;
2152+
checking only the last entry would incorrectly return True when a middle sample
2153+
failed but the final sample passed.
2154+
2155+
In single-sample (non-GT) mode ``_last_batch_size`` is 1, so only the most
2156+
recent entry is inspected (original behaviour).
2157+
2158+
Synthetic gate entries (keys beginning with ``_``) are skipped.
2159+
Returns False when history is empty or any judge score does not meet its threshold.
21492160
21502161
This is used to decide whether variation generation should preserve the current
2151-
behavior and only optimise for cost, rather than trying to improve quality further.
2162+
behaviour and only optimise for cost, rather than trying to improve quality further.
21522163
"""
21532164
if not self._history or not self._options.judges:
21542165
return False
2155-
recent = self._history[-1]
2156-
if not recent.scores:
2157-
return False
2158-
for key, judge in self._options.judges.items():
2159-
result = recent.scores.get(key)
2160-
if result is None:
2161-
return False
2162-
threshold = judge.threshold if judge.threshold is not None else 1.0
2163-
if not judge_passed(result.score, threshold, judge.is_inverted):
2166+
batch = self._history[-self._last_batch_size:]
2167+
for ctx in batch:
2168+
if not ctx.scores:
21642169
return False
2170+
for key, judge in self._options.judges.items():
2171+
result = ctx.scores.get(key)
2172+
if result is None:
2173+
return False
2174+
threshold = judge.threshold if judge.threshold is not None else 1.0
2175+
if not judge_passed(result.score, threshold, judge.is_inverted):
2176+
return False
21652177
return True
21662178

21672179
def _apply_duration_gate(
@@ -2201,15 +2213,12 @@ def _apply_duration_gate(
22012213
rationale = "Latency gate passed (no baseline)."
22022214
score = 1.0
22032215
else:
2204-
if self._baseline_duration_ms is not None and ctx.duration_ms is not None:
2205-
rationale = (
2206-
f"Latency improvement gate failed: {ctx.duration_ms:.0f}ms did not improve "
2207-
f"by {int((1 - _DURATION_TOLERANCE) * 100)}% vs baseline "
2208-
f"{self._baseline_duration_ms:.0f}ms "
2209-
f"(required < {self._baseline_duration_ms * _DURATION_TOLERANCE:.0f}ms)."
2210-
)
2211-
else:
2212-
rationale = "Latency gate failed (no baseline data)."
2216+
rationale = (
2217+
f"Latency improvement gate failed: {ctx.duration_ms:.0f}ms did not improve "
2218+
f"by {int((1 - _DURATION_TOLERANCE) * 100)}% vs baseline "
2219+
f"{self._baseline_duration_ms:.0f}ms "
2220+
f"(required < {self._baseline_duration_ms * _DURATION_TOLERANCE:.0f}ms)."
2221+
)
22132222
score = 0.0
22142223
ctx = dataclasses.replace(
22152224
ctx,
@@ -2257,15 +2266,12 @@ def _apply_cost_gate(
22572266
rationale = "Cost gate passed (no baseline)."
22582267
score = 1.0
22592268
else:
2260-
if self._baseline_cost_usd is not None and ctx.estimated_cost_usd is not None:
2261-
rationale = (
2262-
f"Cost improvement gate failed: {ctx.estimated_cost_usd:.6f} did not improve "
2263-
f"by {int((1 - _COST_TOLERANCE) * 100)}% vs baseline "
2264-
f"{self._baseline_cost_usd:.6f} "
2265-
f"(required < {self._baseline_cost_usd * _COST_TOLERANCE:.6f})."
2266-
)
2267-
else:
2268-
rationale = "Cost gate failed (no baseline data)."
2269+
rationale = (
2270+
f"Cost improvement gate failed: {ctx.estimated_cost_usd:.6f} did not improve "
2271+
f"by {int((1 - _COST_TOLERANCE) * 100)}% vs baseline "
2272+
f"{self._baseline_cost_usd:.6f} "
2273+
f"(required < {self._baseline_cost_usd * _COST_TOLERANCE:.6f})."
2274+
)
22692275
score = 0.0
22702276
ctx = dataclasses.replace(
22712277
ctx,
@@ -2600,6 +2606,7 @@ async def _run_optimization(
26002606
self._last_succeeded_context = None
26012607
self._last_optimization_result_id = None
26022608
self._total_token_usage = 0
2609+
self._last_batch_size = 1
26032610
self._initialize_class_members_from_config(agent_config)
26042611

26052612
# If the LD flag doesn't carry a model name, seed from the first model choice

packages/optimization/src/ldai_optimizer/util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ def estimate_cost(
336336
337337
:param usage: Token usage from the agent call. When ``None``, returns ``None``.
338338
:param model_config: Model config dict from ``get_model_configs()``, or ``None``.
339-
:return: Estimated cost in USD, or ``None`` if usage or pricing data is absent.
339+
:return: Estimated cost in USD, or ``None`` if usage or pricing data is absent, or if
340+
both ``usage.input`` and ``usage.output`` are ``None`` (no token counts available).
340341
"""
341342
if usage is None:
342343
return None
@@ -348,8 +349,11 @@ def estimate_cost(
348349
return None
349350

350351
cost = 0.0
352+
computed = False
351353
if input_price is not None and usage.input is not None:
352354
cost += usage.input * input_price
355+
computed = True
353356
if output_price is not None and usage.output is not None:
354357
cost += usage.output * output_price
355-
return cost
358+
computed = True
359+
return cost if computed else None

packages/optimization/tests/test_client.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5094,6 +5094,26 @@ def test_zero_usage_with_pricing_returns_zero(self):
50945094
model_config = {"costPerInputToken": 0.001, "costPerOutputToken": 0.002}
50955095
assert estimate_cost(usage, model_config) == pytest.approx(0.0)
50965096

5097+
def test_returns_none_when_both_token_counts_are_none(self):
5098+
# Pricing exists but both input and output are None — no token counts to
5099+
# compute from, so we must return None rather than 0.0 to avoid
5100+
# cost-gate treating unknown cost as zero cost.
5101+
usage = TokenUsage(total=None, input=None, output=None)
5102+
model_config = {"costPerInputToken": 0.001, "costPerOutputToken": 0.002}
5103+
assert estimate_cost(usage, model_config) is None
5104+
5105+
def test_returns_partial_cost_when_only_input_count_is_none(self):
5106+
# Only output count available — should still compute a partial cost.
5107+
usage = TokenUsage(total=40, input=None, output=40)
5108+
model_config = {"costPerInputToken": 0.001, "costPerOutputToken": 0.002}
5109+
assert estimate_cost(usage, model_config) == pytest.approx(40 * 0.002)
5110+
5111+
def test_returns_partial_cost_when_only_output_count_is_none(self):
5112+
# Only input count available — should still compute a partial cost.
5113+
usage = TokenUsage(total=60, input=60, output=None)
5114+
model_config = {"costPerInputToken": 0.001, "costPerOutputToken": 0.002}
5115+
assert estimate_cost(usage, model_config) == pytest.approx(60 * 0.001)
5116+
50975117

50985118
# ---------------------------------------------------------------------------
50995119
# _acceptance_criteria_implies_cost_optimization
@@ -5629,10 +5649,11 @@ def test_gate_scores_do_not_affect_result(self):
56295649
assert self.client._all_judges_passing() is True
56305650

56315651
def test_uses_most_recent_history_entry(self):
5632-
"""Only the last history entry is inspected."""
5652+
"""In non-GT mode (_last_batch_size=1) only the last history entry is inspected."""
56335653
self.client._options = _make_options(judges={
56345654
"accuracy": OptimizationJudge(threshold=0.8, acceptance_statement="accurate"),
56355655
})
5656+
self.client._last_batch_size = 1
56365657
self.client._history = [
56375658
self._ctx_with_scores({"accuracy": JudgeResult(score=0.5, rationale="early fail")}, iteration=1),
56385659
self._ctx_with_scores({"accuracy": JudgeResult(score=1.0, rationale="later pass")}, iteration=2),
@@ -5653,6 +5674,60 @@ def test_inverted_judge_fails_when_score_above_threshold(self):
56535674
self.client._history = [self._ctx_with_scores({"toxicity": JudgeResult(score=0.5, rationale="toxic")})]
56545675
assert self.client._all_judges_passing() is False
56555676

5677+
# --- GT batch tests ---
5678+
5679+
def test_gt_batch_last_sample_passes_but_earlier_fails_returns_false(self):
5680+
"""Core GT bug: if any sample in the batch failed, must return False even if the last passed."""
5681+
self.client._options = _make_options(judges={
5682+
"accuracy": OptimizationJudge(threshold=0.8, acceptance_statement="accurate"),
5683+
})
5684+
self.client._last_batch_size = 3
5685+
self.client._history = [
5686+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.3, rationale="fail")}, iteration=1), # FAILS
5687+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.9, rationale="ok")}, iteration=2),
5688+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.95, rationale="ok")}, iteration=3),
5689+
]
5690+
assert self.client._all_judges_passing() is False
5691+
5692+
def test_gt_batch_all_samples_pass_returns_true(self):
5693+
self.client._options = _make_options(judges={
5694+
"accuracy": OptimizationJudge(threshold=0.8, acceptance_statement="accurate"),
5695+
})
5696+
self.client._last_batch_size = 3
5697+
self.client._history = [
5698+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.85, rationale="ok")}, iteration=1),
5699+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.90, rationale="ok")}, iteration=2),
5700+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.95, rationale="ok")}, iteration=3),
5701+
]
5702+
assert self.client._all_judges_passing() is True
5703+
5704+
def test_gt_batch_middle_sample_fails_returns_false(self):
5705+
self.client._options = _make_options(judges={
5706+
"accuracy": OptimizationJudge(threshold=0.8, acceptance_statement="accurate"),
5707+
})
5708+
self.client._last_batch_size = 3
5709+
self.client._history = [
5710+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.95, rationale="ok")}, iteration=1),
5711+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.20, rationale="fail")}, iteration=2), # FAILS
5712+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.95, rationale="ok")}, iteration=3),
5713+
]
5714+
assert self.client._all_judges_passing() is False
5715+
5716+
def test_gt_batch_size_respected_ignores_older_batches(self):
5717+
"""Entries outside the current batch window should not influence the result."""
5718+
self.client._options = _make_options(judges={
5719+
"accuracy": OptimizationJudge(threshold=0.8, acceptance_statement="accurate"),
5720+
})
5721+
self.client._last_batch_size = 2
5722+
# 4 entries; batch covers last 2; first 2 are stale (from a previous attempt)
5723+
self.client._history = [
5724+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.1, rationale="old fail")}, iteration=1),
5725+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.1, rationale="old fail")}, iteration=2),
5726+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.9, rationale="ok")}, iteration=3),
5727+
self._ctx_with_scores({"accuracy": JudgeResult(score=0.9, rationale="ok")}, iteration=4),
5728+
]
5729+
assert self.client._all_judges_passing() is True
5730+
56565731

56575732
class TestBuildNewVariationPromptCost:
56585733
def _make_history(self) -> list:

0 commit comments

Comments
 (0)