Skip to content

Commit 7074cfa

Browse files
committed
feat: dx improvements for optimization package
1 parent 8f3468f commit 7074cfa

3 files changed

Lines changed: 58 additions & 48 deletions

File tree

packages/optimization/src/ldai_optimization/client.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
AIJudgeCallConfig,
2020
GroundTruthOptimizationOptions,
2121
GroundTruthSample,
22+
HandleJudgeCall,
2223
JudgeResult,
2324
OptimizationContext,
2425
OptimizationFromConfigOptions,
@@ -228,6 +229,11 @@ def _create_optimization_context(
228229
iteration=iteration,
229230
)
230231

232+
@property
233+
def _judge_call(self) -> HandleJudgeCall:
234+
"""Return the judge callable, falling back to handle_agent_call when not set."""
235+
return self._options.handle_judge_call or self._options.handle_agent_call
236+
231237
def _safe_status_update(
232238
self,
233239
status: Literal[
@@ -569,10 +575,9 @@ async def _evaluate_config_judge(
569575
LDMessage(role="user", content=judge_user_input),
570576
]
571577

572-
# Collect model parameters from the judge config, separating out any existing tools
573-
model_name = (
574-
judge_config.model.name if judge_config.model else self._options.judge_model
575-
)
578+
# Always use the global judge_model; model parameters (temperature, etc.) from
579+
# the judge flag are still forwarded, but the model name is never overridden.
580+
model_name = self._options.judge_model
576581
model_params: Dict[str, Any] = {}
577582
tools: List[ToolDefinition] = []
578583
if judge_config.model and judge_config.model._parameters:
@@ -615,8 +620,8 @@ async def _evaluate_config_judge(
615620
)
616621

617622
_judge_start = time.monotonic()
618-
result = self._options.handle_judge_call(
619-
judge_key, judge_call_config, judge_ctx
623+
result = self._judge_call(
624+
judge_key, judge_call_config, judge_ctx, True
620625
)
621626
judge_response: OptimizationResponse = await await_if_needed(result)
622627
judge_duration_ms = (time.monotonic() - _judge_start) * 1000
@@ -776,8 +781,8 @@ async def _evaluate_acceptance_judge(
776781
)
777782

778783
_judge_start = time.monotonic()
779-
result = self._options.handle_judge_call(
780-
judge_key, judge_call_config, judge_ctx
784+
result = self._judge_call(
785+
judge_key, judge_call_config, judge_ctx, True
781786
)
782787
judge_response: OptimizationResponse = await await_if_needed(result)
783788
judge_duration_ms = (time.monotonic() - _judge_start) * 1000
@@ -1318,6 +1323,7 @@ async def _generate_new_variation(
13181323
self._agent_key,
13191324
agent_config,
13201325
variation_ctx,
1326+
False,
13211327
)
13221328
variation_response: OptimizationResponse = await await_if_needed(result)
13231329
response_str = variation_response.output
@@ -1717,6 +1723,7 @@ async def _execute_agent_turn(
17171723
self._agent_key,
17181724
self._build_agent_config_for_context(optimize_context),
17191725
optimize_context,
1726+
False,
17201727
)
17211728
agent_response: OptimizationResponse = await await_if_needed(result)
17221729
agent_duration_ms = (time.monotonic() - _agent_start) * 1000

packages/optimization/src/ldai_optimization/dataclasses.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,12 @@ class OptimizationJudgeContext:
282282
# the concrete types (AIAgentConfig / AIJudgeCallConfig) continue to work
283283
# because those types structurally satisfy the Protocols.
284284
HandleAgentCall = Union[
285-
Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse],
286-
Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]],
285+
Callable[[str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse],
286+
Callable[[str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]],
287287
]
288288
HandleJudgeCall = Union[
289-
Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse],
290-
Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]],
289+
Callable[[str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse],
290+
Callable[[str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]],
291291
]
292292

293293
_StatusLiteral = Literal[
@@ -315,7 +315,8 @@ class OptimizationOptions:
315315
] # choices of interpolated variables to be chosen at random per turn, 1 min required
316316
# Actual agent/completion (judge) calls - Required
317317
handle_agent_call: HandleAgentCall
318-
handle_judge_call: HandleJudgeCall
318+
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
319+
handle_judge_call: Optional[HandleJudgeCall] = None
319320
# Criteria for pass/fail - Optional
320321
user_input_options: Optional[List[str]] = (
321322
None # optional list of user input messages to randomly select from
@@ -401,7 +402,8 @@ class GroundTruthOptimizationOptions:
401402
model_choices: List[str]
402403
judge_model: str
403404
handle_agent_call: HandleAgentCall
404-
handle_judge_call: HandleJudgeCall
405+
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
406+
handle_judge_call: Optional[HandleJudgeCall] = None
405407
judges: Optional[Dict[str, OptimizationJudge]] = None
406408
on_turn: Optional[Callable[[OptimizationContext], bool]] = None
407409
on_sample_result: Optional[Callable[[OptimizationContext], None]] = None
@@ -461,7 +463,8 @@ class OptimizationFromConfigOptions:
461463

462464
project_key: str
463465
handle_agent_call: HandleAgentCall
464-
handle_judge_call: HandleJudgeCall
466+
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
467+
handle_judge_call: Optional[HandleJudgeCall] = None
465468
on_turn: Optional[Callable[["OptimizationContext"], bool]] = None
466469
on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None
467470
on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None

0 commit comments

Comments
 (0)