Skip to content

Commit e8c6692

Browse files
committed
feat: add token limit handling
1 parent 3bce893 commit e8c6692

4 files changed

Lines changed: 426 additions & 0 deletions

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def __init__(self, ldClient: LDAIClient) -> None:
157157
self._last_succeeded_context: Optional[OptimizationContext] = None
158158
self._last_optimization_result_id: Optional[str] = None
159159
self._initial_tool_keys: List[str] = []
160+
self._total_token_usage: int = 0
160161

161162
if os.environ.get("LAUNCHDARKLY_API_KEY"):
162163
self._has_api_key = True
@@ -966,12 +967,14 @@ async def _run_ground_truth_optimization(
966967
on_passing_result=gt_options.on_passing_result,
967968
on_failing_result=gt_options.on_failing_result,
968969
on_status_update=gt_options.on_status_update,
970+
token_limit=gt_options.token_limit,
969971
)
970972
self._options = bridge
971973
self._agent_config = agent_config
972974
self._last_run_succeeded = False
973975
self._last_succeeded_context = None
974976
self._last_optimization_result_id = None
977+
self._total_token_usage = 0
975978
self._initialize_class_members_from_config(agent_config)
976979

977980
# Seed from the first model choice on the first iteration
@@ -1036,6 +1039,26 @@ async def _run_ground_truth_optimization(
10361039
linear_iter,
10371040
expected_response=sample.expected_response,
10381041
)
1042+
self._accumulate_tokens(optimize_context)
1043+
if self._is_token_limit_exceeded():
1044+
logger.error(
1045+
"[GT Attempt %d] -> Token limit exceeded on sample %d (total=%d)",
1046+
attempt,
1047+
i + 1,
1048+
self._total_token_usage,
1049+
)
1050+
attempt_results.append(optimize_context)
1051+
self._last_run_succeeded = False
1052+
self._last_succeeded_context = None
1053+
self._safe_status_update("failure", optimize_context, linear_iter)
1054+
if self._options.on_failing_result:
1055+
try:
1056+
self._options.on_failing_result(optimize_context)
1057+
except Exception:
1058+
logger.exception(
1059+
"[GT Attempt %d] -> on_failing_result callback failed", attempt
1060+
)
1061+
return attempt_results
10391062

10401063
# Per-sample pass/fail check
10411064
if self._options.on_turn is not None:
@@ -1681,6 +1704,7 @@ def _persist_and_forward(
16811704
on_passing_result=options.on_passing_result,
16821705
on_failing_result=options.on_failing_result,
16831706
on_status_update=_persist_and_forward,
1707+
token_limit=config.get("tokenLimit"),
16841708
)
16851709

16861710
variable_choices: List[Dict[str, Any]] = config["variableChoices"] or [{}]
@@ -1700,6 +1724,7 @@ def _persist_and_forward(
17001724
on_passing_result=options.on_passing_result,
17011725
on_failing_result=options.on_failing_result,
17021726
on_status_update=_persist_and_forward,
1727+
token_limit=config.get("tokenLimit"),
17031728
)
17041729

17051730
async def _execute_agent_turn(
@@ -1780,6 +1805,31 @@ async def _execute_agent_turn(
17801805

17811806
return result_ctx
17821807

1808+
def _accumulate_tokens(self, optimize_context: OptimizationContext) -> None:
1809+
"""Add token usage from a completed turn to the running total.
1810+
1811+
Sums the agent's token usage and each judge's token usage from the given
1812+
context and adds them to ``_total_token_usage``.
1813+
1814+
:param optimize_context: The completed turn context containing usage data.
1815+
"""
1816+
if optimize_context.usage is not None:
1817+
self._total_token_usage += optimize_context.usage.total or 0
1818+
for judge_result in optimize_context.scores.values():
1819+
if judge_result.usage is not None:
1820+
self._total_token_usage += judge_result.usage.total or 0
1821+
1822+
def _is_token_limit_exceeded(self) -> bool:
1823+
"""Return True if the accumulated token usage has met or exceeded the configured limit.
1824+
1825+
Returns False when no token limit is set so callers can use this as a
1826+
simple guard without needing to check for ``None`` themselves.
1827+
1828+
:return: True if token limit is set and ``_total_token_usage >= token_limit``.
1829+
"""
1830+
limit: Optional[int] = getattr(self._options, "token_limit", None)
1831+
return limit is not None and self._total_token_usage >= limit
1832+
17831833
def _evaluate_response(self, optimize_context: OptimizationContext) -> bool:
17841834
"""
17851835
Determine whether the current iteration's scores meet all judge thresholds.
@@ -2091,6 +2141,15 @@ async def _run_validation_phase(
20912141
)
20922142
self._safe_status_update("generating", val_ctx, val_iter)
20932143
val_ctx = await self._execute_agent_turn(val_ctx, val_iter)
2144+
self._accumulate_tokens(val_ctx)
2145+
if self._is_token_limit_exceeded():
2146+
logger.error(
2147+
"[Validation %d/%d] -> Token limit exceeded (total=%d)",
2148+
i + 1,
2149+
validation_count,
2150+
self._total_token_usage,
2151+
)
2152+
return False, val_ctx
20942153

20952154
if options.on_turn is not None:
20962155
try:
@@ -2147,6 +2206,7 @@ async def _run_optimization(
21472206
self._last_run_succeeded = False
21482207
self._last_succeeded_context = None
21492208
self._last_optimization_result_id = None
2209+
self._total_token_usage = 0
21502210
self._initialize_class_members_from_config(agent_config)
21512211

21522212
# If the LD flag doesn't carry a model name, seed from the first model choice
@@ -2192,6 +2252,14 @@ async def _run_optimization(
21922252
optimize_context = await self._execute_agent_turn(
21932253
optimize_context, iteration
21942254
)
2255+
self._accumulate_tokens(optimize_context)
2256+
if self._is_token_limit_exceeded():
2257+
logger.error(
2258+
"[Iteration %d] -> Token limit exceeded (total=%d)",
2259+
iteration,
2260+
self._total_token_usage,
2261+
)
2262+
return self._handle_failure(optimize_context, iteration)
21952263

21962264
# Manual path: on_turn callback gives caller full control over pass/fail
21972265
if self._options.on_turn is not None:
@@ -2229,6 +2297,8 @@ async def _run_optimization(
22292297
)
22302298
if all_valid:
22312299
return self._handle_success(optimize_context, iteration)
2300+
if self._is_token_limit_exceeded():
2301+
return self._handle_failure(last_ctx, iteration)
22322302
# Validation failed — treat as a normal failed attempt.
22332303
# Use optimize_context (the main iteration) for terminal API events so
22342304
# the persisted record's completionResponse and userInput stay aligned.

packages/optimization/src/ldai_optimizer/dataclasses.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ class OptimizationOptions:
346346
on_failing_result: Optional[Callable[[OptimizationContext], None]] = None
347347
# called to provide status updates during the optimization flow
348348
on_status_update: Optional[Callable[[_StatusLiteral, OptimizationContext], None]] = None
349+
token_limit: Optional[int] = None # stop the run when total token usage reaches this value
349350

350351
def __post_init__(self):
351352
"""Validate required options."""
@@ -433,6 +434,7 @@ class GroundTruthOptimizationOptions:
433434
project_key: Optional[str] = None # required when auto_commit=True
434435
output_key: Optional[str] = None # variation key/name; auto-generated if omitted
435436
base_url: Optional[str] = None # override to target a non-default LD instance
437+
token_limit: Optional[int] = None # stop the run when total token usage reaches this value
436438

437439
def __post_init__(self):
438440
"""Validate required options."""

packages/optimization/src/ldai_optimizer/ld_api_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class AgentOptimizationConfig(_AgentOptimizationConfigRequired, total=False):
8989

9090
groundTruthResponses: List[str]
9191
metricKey: str
92+
tokenLimit: int
9293

9394

9495
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)