Skip to content

Commit 94de596

Browse files
committed
feat: adds ability to optimize for cost
1 parent 23baeb4 commit 94de596

5 files changed

Lines changed: 691 additions & 9 deletions

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 146 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
LDApiClient,
5050
)
5151
from ldai_optimizer.prompts import (
52+
_acceptance_criteria_implies_cost_optimization,
5253
_acceptance_criteria_implies_duration_optimization,
5354
build_message_history_text,
5455
build_new_variation_prompt,
@@ -57,6 +58,7 @@
5758
from ldai_optimizer.util import (
5859
RedactionFilter,
5960
await_if_needed,
61+
estimate_cost,
6062
extract_json_from_response,
6163
generate_slug,
6264
interpolate_variables,
@@ -128,6 +130,11 @@ def _compute_validation_count(pool_size: int) -> int:
128130
# under 80% of the baseline — i.e. at least 20% improvement.
129131
_DURATION_TOLERANCE = 0.80
130132

133+
# Cost gate: a candidate must cost at most this fraction of the baseline
134+
# (history[0].estimated_cost_usd) to pass when acceptance criteria imply a
135+
# cost reduction goal. 0.80 means at least 20% cheaper than the baseline.
136+
_COST_TOLERANCE = 0.80
137+
131138
# Maps SDK status strings to the API status/activity values expected by
132139
# agent_optimization_result records. Defined at module level to avoid
133140
# allocating the dict on every on_status_update invocation.
@@ -160,6 +167,7 @@ def __init__(self, ldClient: LDAIClient) -> None:
160167
self._last_optimization_result_id: Optional[str] = None
161168
self._initial_tool_keys: List[str] = []
162169
self._total_token_usage: int = 0
170+
self._model_configs: List[Dict[str, Any]] = []
163171

164172
if os.environ.get("LAUNCHDARKLY_API_KEY"):
165173
self._has_api_key = True
@@ -392,6 +400,7 @@ async def _call_judges(
392400
agent_tools: Optional[List[ToolDefinition]] = None,
393401
expected_response: Optional[str] = None,
394402
agent_duration_ms: Optional[float] = None,
403+
agent_usage: Optional[Any] = None,
395404
) -> Dict[str, JudgeResult]:
396405
"""
397406
Call all judges in parallel (auto-path).
@@ -411,6 +420,8 @@ async def _call_judges(
411420
:param agent_duration_ms: Wall-clock duration of the agent call in milliseconds.
412421
Forwarded to acceptance judges whose statement implies a latency goal so they
413422
can mention the duration change in their rationale.
423+
:param agent_usage: Token usage from the agent call. Forwarded to acceptance judges
424+
whose statement implies a cost goal so they can mention token usage in their rationale.
414425
:return: Dictionary of judge results (score and rationale)
415426
"""
416427
if not self._options.judges:
@@ -464,6 +475,7 @@ async def _call_judges(
464475
agent_tools=resolved_agent_tools,
465476
expected_response=expected_response,
466477
agent_duration_ms=agent_duration_ms,
478+
agent_usage=agent_usage,
467479
)
468480
judge_results[judge_key] = result
469481

@@ -682,6 +694,7 @@ async def _evaluate_acceptance_judge(
682694
agent_tools: Optional[List[ToolDefinition]] = None,
683695
expected_response: Optional[str] = None,
684696
agent_duration_ms: Optional[float] = None,
697+
agent_usage: Optional[Any] = None,
685698
) -> JudgeResult:
686699
"""
687700
Evaluate using an acceptance statement judge.
@@ -699,6 +712,8 @@ async def _evaluate_acceptance_judge(
699712
:param agent_duration_ms: Wall-clock duration of the agent call in milliseconds.
700713
When the acceptance statement implies a latency goal, the judge is instructed
701714
to mention the duration change in its rationale.
715+
:param agent_usage: Token usage from the agent call. When the acceptance statement
716+
implies a cost goal, the judge is instructed to mention token usage and cost.
702717
:return: The judge result with score and rationale
703718
"""
704719
if not optimization_judge.acceptance_statement:
@@ -757,9 +772,64 @@ async def _evaluate_acceptance_judge(
757772
f"This response was {abs(delta_ms):.0f}ms {direction} than the baseline. "
758773
)
759774
instructions += (
760-
"Please mention the duration and any change from baseline in your rationale."
775+
"In your rationale, state the duration and any change from baseline. "
776+
"If the latency goal is not yet met, include specific, actionable suggestions "
777+
"for how the agent's instructions or model choice could be changed to reduce "
778+
"response time — for example: switching to a faster model, shortening the "
779+
"system prompt, or removing instructions that cause multi-step reasoning. "
780+
"These suggestions will be used directly to generate the next variation."
761781
)
762782

783+
if _acceptance_criteria_implies_cost_optimization({judge_key: optimization_judge}):
784+
current_cost = estimate_cost(
785+
agent_usage,
786+
_find_model_config(self._current_model or "", self._model_configs),
787+
)
788+
baseline_cost = (
789+
self._history[0].estimated_cost_usd
790+
if self._history and self._history[0].estimated_cost_usd is not None
791+
else None
792+
)
793+
if current_cost is not None:
794+
has_pricing = (
795+
_find_model_config(self._current_model or "", self._model_configs) or {}
796+
).get("costPerInputToken") is not None
797+
if has_pricing:
798+
cost_str = f"${current_cost:.6f}"
799+
else:
800+
cost_str = f"{int(current_cost)} tokens"
801+
instructions += (
802+
f"\n\nThe acceptance criteria for this judge includes a cost/token-usage goal. "
803+
)
804+
if agent_usage is not None:
805+
instructions += (
806+
f"The agent's response used {agent_usage.input} input tokens "
807+
f"and {agent_usage.output} output tokens "
808+
f"(estimated cost: {cost_str}). "
809+
)
810+
if baseline_cost is not None:
811+
delta = current_cost - baseline_cost
812+
direction = "less" if delta < 0 else "more"
813+
if has_pricing:
814+
baseline_str = f"${baseline_cost:.6f}"
815+
delta_str = f"${abs(delta):.6f}"
816+
else:
817+
baseline_str = f"{int(baseline_cost)} tokens"
818+
delta_str = f"{int(abs(delta))} tokens"
819+
instructions += (
820+
f"The baseline cost (first iteration) was {baseline_str}. "
821+
f"This response cost {delta_str} {direction} than the baseline. "
822+
)
823+
instructions += (
824+
"In your rationale, state the token usage and cost, and any change from baseline. "
825+
"If the cost goal is not yet met, include specific, actionable suggestions "
826+
"for how the agent's instructions or model choice could be changed to reduce "
827+
"cost — for example: switching to a cheaper model, shortening the system prompt "
828+
"to reduce input tokens, removing unnecessary output instructions, or tightening "
829+
"response length constraints. "
830+
"These suggestions will be used directly to generate the next variation."
831+
)
832+
763833
if resolved_variables:
764834
instructions += f"\n\nThe following variables were available to the agent: {json.dumps(resolved_variables)}"
765835

@@ -1082,6 +1152,11 @@ async def _run_ground_truth_optimization(
10821152
):
10831153
sample_passed = self._evaluate_duration(optimize_context)
10841154

1155+
if sample_passed and _acceptance_criteria_implies_cost_optimization(
1156+
self._options.judges
1157+
):
1158+
sample_passed = self._evaluate_cost(optimize_context)
1159+
10851160
if not sample_passed:
10861161
logger.info(
10871162
"[GT Attempt %d] -> Sample %d/%d FAILED",
@@ -1227,12 +1302,19 @@ def _apply_new_variation_response(
12271302
# This is a deterministic safety net for when the LLM ignores the prompt
12281303
# instructions and hardcodes a concrete value (e.g. "user-123") instead
12291304
# of the placeholder ("{{user_id}}").
1305+
# Only check the variables that were actually used for this invocation so
1306+
# we don't spuriously replace values that happen to appear in other choices.
1307+
active_variables = (
1308+
[variation_ctx.current_variables]
1309+
if variation_ctx.current_variables
1310+
else self._options.variable_choices
1311+
)
12301312
self._current_instructions, placeholder_warnings = restore_variable_placeholders(
12311313
self._current_instructions,
1232-
self._options.variable_choices,
1314+
active_variables,
12331315
)
12341316
for msg in placeholder_warnings:
1235-
logger.warning("[Iteration %d] -> %s", iteration, msg)
1317+
logger.debug("[Iteration %d] -> %s", iteration, msg)
12361318

12371319
self._current_parameters = response_data["current_parameters"]
12381320

@@ -1321,6 +1403,9 @@ async def _generate_new_variation(
13211403
optimize_for_duration = _acceptance_criteria_implies_duration_optimization(
13221404
self._options.judges
13231405
)
1406+
optimize_for_cost = _acceptance_criteria_implies_cost_optimization(
1407+
self._options.judges
1408+
)
13241409
instructions = build_new_variation_prompt(
13251410
self._history,
13261411
self._options.judges,
@@ -1331,6 +1416,7 @@ async def _generate_new_variation(
13311416
self._options.variable_choices,
13321417
self._initial_instructions,
13331418
optimize_for_duration=optimize_for_duration,
1419+
optimize_for_cost=optimize_for_cost,
13341420
)
13351421

13361422
# Create a flat history list (without nested history) to avoid exponential growth
@@ -1424,6 +1510,7 @@ async def optimize_from_config(
14241510
model_configs = api_client.get_model_configs(options.project_key)
14251511
except Exception as exc:
14261512
logger.debug("Could not pre-fetch model configs: %s", exc)
1513+
self._model_configs = model_configs
14271514

14281515
context = random.choice(options.context_choices)
14291516
# _get_agent_config calls _initialize_class_members_from_config internally;
@@ -1793,18 +1880,24 @@ async def _execute_agent_turn(
17931880
agent_tools=agent_tools,
17941881
expected_response=expected_response,
17951882
agent_duration_ms=agent_duration_ms,
1883+
agent_usage=agent_response.usage,
17961884
)
17971885

17981886
# Build the fully-populated result context before firing the evaluating event so
17991887
# the PATCH includes scores, generationLatency, and completionResponse. This is
18001888
# particularly important for non-final GT samples which receive no further status
18011889
# events — without this, those fields would never be written to their API records.
1890+
agent_cost = estimate_cost(
1891+
agent_response.usage,
1892+
_find_model_config(self._current_model or "", self._model_configs),
1893+
)
18021894
result_ctx = dataclasses.replace(
18031895
optimize_context,
18041896
completion_response=completion_response,
18051897
scores=scores,
18061898
duration_ms=agent_duration_ms,
18071899
usage=agent_response.usage,
1900+
estimated_cost_usd=agent_cost,
18081901
)
18091902

18101903
if self._options.judges:
@@ -1829,13 +1922,13 @@ def _accumulate_tokens(self, optimize_context: OptimizationContext) -> None:
18291922
def _is_token_limit_exceeded(self) -> bool:
18301923
"""Return True if the accumulated token usage has met or exceeded the configured limit.
18311924
1832-
Returns False when no token limit is set so callers can use this as a
1833-
simple guard without needing to check for ``None`` themselves.
1925+
Returns False when no token limit is set, or when the limit is 0 (which is
1926+
treated as "no limit" — a sentinel value meaning the field was left unset).
18341927
1835-
:return: True if token limit is set and ``_total_token_usage >= token_limit``.
1928+
:return: True if a positive token limit is set and ``_total_token_usage >= token_limit``.
18361929
"""
18371930
limit: Optional[int] = getattr(self._options, "token_limit", None)
1838-
return limit is not None and self._total_token_usage >= limit
1931+
return bool(limit) and self._total_token_usage >= limit
18391932

18401933
def _evaluate_response(self, optimize_context: OptimizationContext) -> bool:
18411934
"""
@@ -1896,6 +1989,42 @@ def _evaluate_duration(self, optimize_context: OptimizationContext) -> bool:
18961989
)
18971990
return passed
18981991

1992+
def _evaluate_cost(self, optimize_context: OptimizationContext) -> bool:
1993+
"""
1994+
Check whether the candidate's estimated cost meets the improvement target vs. the baseline.
1995+
1996+
The baseline is history[0].estimated_cost_usd — the very first completed iteration,
1997+
representing the original unoptimized configuration's cost. The candidate must be
1998+
at least _COST_TOLERANCE cheaper (default: 20% improvement).
1999+
2000+
The cost value is in USD when model pricing data is available, or raw total token
2001+
count as a proxy when pricing is absent. Both are comparable relative to their
2002+
own baselines.
2003+
2004+
Returns True without blocking when no baseline is available (empty history or
2005+
history[0].estimated_cost_usd is None), or when the candidate's cost was not
2006+
captured. This avoids penalising configurations when cost data is missing.
2007+
2008+
:param optimize_context: The completed turn context containing estimated_cost_usd
2009+
:return: True if the cost requirement is met or cannot be checked
2010+
"""
2011+
if not self._history or self._history[0].estimated_cost_usd is None:
2012+
return True
2013+
if optimize_context.estimated_cost_usd is None:
2014+
return True
2015+
baseline = self._history[0].estimated_cost_usd
2016+
passed = optimize_context.estimated_cost_usd < baseline * _COST_TOLERANCE
2017+
if not passed:
2018+
logger.warning(
2019+
"[Iteration %d] -> Cost check failed: %.6f >= baseline %.6f * %.0f%% (%.6f)",
2020+
optimize_context.iteration,
2021+
optimize_context.estimated_cost_usd,
2022+
baseline,
2023+
_COST_TOLERANCE * 100,
2024+
baseline * _COST_TOLERANCE,
2025+
)
2026+
return passed
2027+
18992028
def _handle_success(
19002029
self, optimize_context: OptimizationContext, iteration: int
19012030
) -> Any:
@@ -2174,6 +2303,11 @@ async def _run_validation_phase(
21742303
):
21752304
sample_passed = self._evaluate_duration(val_ctx)
21762305

2306+
if sample_passed and _acceptance_criteria_implies_cost_optimization(
2307+
self._options.judges
2308+
):
2309+
sample_passed = self._evaluate_cost(val_ctx)
2310+
21772311
last_ctx = val_ctx
21782312

21792313
if not sample_passed:
@@ -2298,6 +2432,11 @@ async def _run_optimization(
22982432
):
22992433
initial_passed = self._evaluate_duration(optimize_context)
23002434

2435+
if initial_passed and _acceptance_criteria_implies_cost_optimization(
2436+
self._options.judges
2437+
):
2438+
initial_passed = self._evaluate_cost(optimize_context)
2439+
23012440
if initial_passed:
23022441
all_valid, last_ctx = await self._run_validation_phase(
23032442
optimize_context, iteration

packages/optimization/src/ldai_optimizer/dataclasses.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ class OptimizationContext:
217217
iteration: int = 0 # current iteration number
218218
duration_ms: Optional[float] = None # wall-clock time for the agent call in milliseconds
219219
usage: Optional[TokenUsage] = None # token usage reported by the agent for this iteration
220+
estimated_cost_usd: Optional[float] = None # estimated cost; USD when pricing available, else total tokens
220221

221222
def copy_without_history(self) -> OptimizationContext:
222223
"""
@@ -236,6 +237,7 @@ def copy_without_history(self) -> OptimizationContext:
236237
iteration=self.iteration,
237238
duration_ms=self.duration_ms,
238239
usage=self.usage,
240+
estimated_cost_usd=self.estimated_cost_usd,
239241
)
240242

241243
def to_json(self) -> Dict[str, Any]:
@@ -261,6 +263,7 @@ def to_json(self) -> Dict[str, Any]:
261263
"history": history_list,
262264
"iteration": self.iteration,
263265
"duration_ms": self.duration_ms,
266+
"estimated_cost_usd": self.estimated_cost_usd,
264267
}
265268
if self.usage is not None:
266269
result["usage"] = {

0 commit comments

Comments
 (0)