Skip to content

Commit 55674ae

Browse files
committed
chore: improve call config, context so they're passable as a single type, remove required context_choices argument and default to anon
1 parent 31c8385 commit 55674ae

4 files changed

Lines changed: 41 additions & 27 deletions

File tree

packages/optimization/src/ldai_optimization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GroundTruthOptimizationOptions,
1212
GroundTruthSample,
1313
LLMCallConfig,
14+
LLMCallContext,
1415
OptimizationContext,
1516
OptimizationFromConfigOptions,
1617
OptimizationJudge,
@@ -30,6 +31,7 @@
3031
'GroundTruthSample',
3132
'LDApiError',
3233
'LLMCallConfig',
34+
'LLMCallContext',
3335
'OptimizationClient',
3436
'OptimizationContext',
3537
'OptimizationFromConfigOptions',

packages/optimization/src/ldai_optimization/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ async def _evaluate_config_judge(
611611

612612
judge_ctx = OptimizationJudgeContext(
613613
user_input=judge_user_input,
614-
variables=variables or {},
614+
current_variables=variables or {},
615615
)
616616

617617
_judge_start = time.monotonic()
@@ -772,7 +772,7 @@ async def _evaluate_acceptance_judge(
772772

773773
judge_ctx = OptimizationJudgeContext(
774774
user_input=judge_user_input,
775-
variables=resolved_variables,
775+
current_variables=resolved_variables,
776776
)
777777

778778
_judge_start = time.monotonic()

packages/optimization/src/ldai_optimization/dataclasses.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,11 @@ class LLMCallConfig(Protocol):
118118
async def handle_llm_call(
119119
key: str,
120120
config: LLMCallConfig,
121-
context: Union[OptimizationContext, OptimizationJudgeContext],
121+
context: LLMCallContext,
122122
) -> OptimizationResponse:
123123
model_name = config.model.name if config.model else "gpt-4o"
124124
instructions = config.instructions or ""
125+
tools = config.model.get_parameter("tools") if config.model else []
125126
...
126127
127128
OptimizationOptions(
@@ -136,6 +137,17 @@ async def handle_llm_call(
136137
instructions: Optional[str]
137138

138139

140+
class LLMCallContext(Protocol):
141+
"""Structural protocol satisfied by both ``OptimizationContext`` and ``OptimizationJudgeContext``.
142+
143+
Use alongside ``LLMCallConfig`` when writing a single handler for both
144+
``handle_agent_call`` and ``handle_judge_call``.
145+
"""
146+
147+
user_input: Optional[str]
148+
current_variables: Dict[str, Any]
149+
150+
139151
@dataclass
140152
class AIJudgeCallConfig:
141153
"""
@@ -257,20 +269,25 @@ class OptimizationJudgeContext:
257269
"""Context for a single judge evaluation turn."""
258270

259271
user_input: str # the agent response being evaluated
260-
variables: Dict[str, Any] = field(default_factory=dict) # variable set used during agent generation
272+
current_variables: Dict[str, Any] = field(default_factory=dict) # variable set used during agent generation
261273

262274

263275
# Shared callback type aliases used by both OptimizationOptions and
264276
# OptimizationFromConfigOptions to avoid duplicating the full signatures.
265277
# Placed here so all referenced types (OptimizationContext, AIJudgeCallConfig,
266278
# OptimizationJudgeContext) are already defined above.
279+
#
280+
# Both aliases use the LLMCallConfig / LLMCallContext Protocols so callers can
281+
# write a single handler for both agent and judge calls. Handlers typed with
282+
# the concrete types (AIAgentConfig / AIJudgeCallConfig) continue to work
283+
# because those types structurally satisfy the Protocols.
267284
HandleAgentCall = Union[
268-
Callable[[str, AIAgentConfig, OptimizationContext], OptimizationResponse],
269-
Callable[[str, AIAgentConfig, OptimizationContext], Awaitable[OptimizationResponse]],
285+
Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse],
286+
Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]],
270287
]
271288
HandleJudgeCall = Union[
272-
Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext], OptimizationResponse],
273-
Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext], Awaitable[OptimizationResponse]],
289+
Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse],
290+
Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]],
274291
]
275292

276293
_StatusLiteral = Literal[
@@ -289,9 +306,7 @@ class OptimizationJudgeContext:
289306
class OptimizationOptions:
290307
"""Options for agent optimization."""
291308

292-
# Required
293-
context_choices: List[Context] # choices of contexts to be used, 1 min required
294-
# Configuration
309+
# Configuration - Required
295310
max_attempts: int
296311
model_choices: List[str] # model ids the LLM can choose from, 1 min required
297312
judge_model: str # which model to use as judge; this should remain consistent
@@ -311,6 +326,10 @@ class OptimizationOptions:
311326
on_turn: Optional[Callable[[OptimizationContext], bool]] = (
312327
None # if you want manual control of pass/fail
313328
)
329+
# Context - Optional; defaults to a single anonymous context
330+
context_choices: List[Context] = field(
331+
default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()]
332+
)
314333
# Auto-commit - Optional
315334
auto_commit: bool = False
316335
project_key: Optional[str] = None # required when auto_commit=True
@@ -323,8 +342,6 @@ class OptimizationOptions:
323342

324343
def __post_init__(self):
325344
"""Validate required options."""
326-
if len(self.context_choices) < 1:
327-
raise ValueError("context_choices must have at least 1 context")
328345
if len(self.model_choices) < 1:
329346
raise ValueError("model_choices must have at least 1 model")
330347
if self.judges is None and self.on_turn is None:
@@ -379,7 +396,6 @@ class GroundTruthOptimizationOptions:
379396
:param on_status_update: Called on each status transition during the run.
380397
"""
381398

382-
context_choices: List[Context]
383399
ground_truth_responses: List[GroundTruthSample]
384400
max_attempts: int
385401
model_choices: List[str]
@@ -400,6 +416,10 @@ class GroundTruthOptimizationOptions:
400416
None,
401417
]
402418
] = None
419+
# Context - Optional; defaults to a single anonymous context
420+
context_choices: List[Context] = field(
421+
default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()]
422+
)
403423
# Auto-commit - Optional
404424
auto_commit: bool = False
405425
project_key: Optional[str] = None # required when auto_commit=True
@@ -408,8 +428,6 @@ class GroundTruthOptimizationOptions:
408428

409429
def __post_init__(self):
410430
"""Validate required options."""
411-
if len(self.context_choices) < 1:
412-
raise ValueError("context_choices must have at least 1 context")
413431
if len(self.model_choices) < 1:
414432
raise ValueError("model_choices must have at least 1 model")
415433
if len(self.ground_truth_responses) < 1:
@@ -442,20 +460,18 @@ class OptimizationFromConfigOptions:
442460
"""
443461

444462
project_key: str
445-
context_choices: List[Context]
446463
handle_agent_call: HandleAgentCall
447464
handle_judge_call: HandleJudgeCall
448465
on_turn: Optional[Callable[["OptimizationContext"], bool]] = None
449466
on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None
450467
on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None
451468
on_failing_result: Optional[Callable[["OptimizationContext"], None]] = None
452469
on_status_update: Optional[Callable[[_StatusLiteral, "OptimizationContext"], None]] = None
470+
# Context - Optional; defaults to a single anonymous context
471+
context_choices: List[Context] = field(
472+
default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()]
473+
)
453474
base_url: Optional[str] = None
454475
# Auto-commit defaults to True for config-driven runs; set False to disable
455476
auto_commit: bool = True
456477
output_key: Optional[str] = None # variation key/name; auto-generated if omitted
457-
458-
def __post_init__(self):
459-
"""Validate required options."""
460-
if len(self.context_choices) < 1:
461-
raise ValueError("context_choices must have at least 1 context")

packages/optimization/tests/test_client.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ async def test_variables_in_context(self):
513513
)
514514
call_args = self.handle_judge_call.call_args
515515
_, _, ctx = call_args.args
516-
assert ctx.variables == variables
516+
assert ctx.current_variables == variables
517517

518518
async def test_duration_context_added_to_instructions_when_latency_keyword_present(self):
519519
"""When acceptance statement has a latency keyword and agent_duration_ms is provided,
@@ -2581,10 +2581,6 @@ def test_valid_options_created(self):
25812581
opts = self._make()
25822582
assert len(opts.ground_truth_responses) == 1
25832583

2584-
def test_raises_empty_context_choices(self):
2585-
with pytest.raises(ValueError, match="context_choices"):
2586-
self._make(context_choices=[])
2587-
25882584
def test_raises_empty_model_choices(self):
25892585
with pytest.raises(ValueError, match="model_choices"):
25902586
self._make(model_choices=[])

0 commit comments

Comments
 (0)