-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathdataclasses.py
More file actions
510 lines (433 loc) · 20.7 KB
/
dataclasses.py
File metadata and controls
510 lines (433 loc) · 20.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
"""Dataclasses for the LaunchDarkly AI optimization package."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
from ldai import AIAgentConfig
from ldai.models import LDMessage, ModelConfig
from ldai.tracker import TokenUsage
from ldclient import Context
from typing_extensions import Protocol
@dataclass
class OptimizationResponse:
    """Value returned by the ``handle_agent_call`` and ``handle_judge_call`` callbacks.

    :param output: Text produced by the LLM for this call.
    :param usage: Token usage for this call, when known. Leave it out entirely
        (or set its fields to 0) if the framework in use cannot report token counts.
    """

    output: str
    usage: Optional[TokenUsage] = None
@dataclass
class JudgeResult:
    """Outcome of a single judge evaluation."""

    score: float
    rationale: Optional[str] = None
    duration_ms: Optional[float] = None
    usage: Optional[TokenUsage] = None
    estimated_cost_usd: Optional[float] = None

    def to_json(self) -> Dict[str, Any]:
        """
        Build a plain-dictionary view of this result suitable for ``json.dumps()``.

        ``score``, ``rationale`` and ``duration_ms`` are always present (possibly
        ``None``); ``usage`` and ``estimated_cost_usd`` are added only when set.

        :return: Dictionary representation of the judge result
        """
        payload: Dict[str, Any] = dict(
            score=self.score,
            rationale=self.rationale,
            duration_ms=self.duration_ms,
        )

        token_usage = self.usage
        if token_usage is not None:
            payload["usage"] = dict(
                total=token_usage.total,
                input=token_usage.input,
                output=token_usage.output,
            )

        cost = self.estimated_cost_usd
        if cost is not None:
            payload["estimated_cost_usd"] = cost

        return payload
@dataclass
class ToolDefinition:
    """
    Generic tool definition for enforcing structured output from LLM responses.

    This tool can be used with any LLM provider to ensure responses conform to
    a specific JSON schema. The tool takes the LLM's response and returns
    parsed and validated data according to the input_schema.
    """

    name: str
    description: str
    input_schema: Dict[str, Any]  # JSON schema defining the expected output structure
    type: Literal["function"] = "function"

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the tool definition to a dictionary format compatible with LLM APIs.

        :return: Dictionary with the ``name``, ``description``, ``input_schema``
            and ``type`` keys
        """
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.input_schema,
            "type": self.type,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolDefinition":
        """
        Construct a ToolDefinition from a plain dictionary.

        A key that is missing *or* explicitly ``None`` (e.g. a JSON ``null``)
        falls back to its empty default, so the resulting fields always match
        their declared types.

        :param data: Dictionary normally carrying ``name``, ``description``,
            ``input_schema`` and ``type``; each one is optional.
        :return: A new ToolDefinition instance
        """

        def _value_or(key: str, default: Any) -> Any:
            # Treat an explicit None the same as an absent key.
            found = data.get(key)
            return default if found is None else found

        return cls(
            name=_value_or("name", ""),
            description=_value_or("description", ""),
            input_schema=_value_or("input_schema", {}),
            type=_value_or("type", "function"),
        )
class LLMCallConfig(Protocol):
    """Structural type matched by both ``AIAgentConfig`` and ``AIJudgeCallConfig``.

    Annotate a handler's config parameter with this protocol when one function
    should serve as both ``handle_agent_call`` and ``handle_judge_call``::

        async def handle_llm_call(
            key: str,
            config: LLMCallConfig,
            context: LLMCallContext,
            is_evaluation: bool,
        ) -> OptimizationResponse:
            model_name = config.model.name if config.model else "gpt-4o"
            instructions = config.instructions or ""
            tools = config.model.get_parameter("tools") if config.model else []
            ...

        OptimizationOptions(
            handle_agent_call=handle_llm_call,
            handle_judge_call=handle_llm_call,
            ...
        )

    The protocol only requires read access to ``key``, ``model`` and
    ``instructions``; ``model`` and ``instructions`` may be ``None``.
    """

    @property
    def key(self) -> str:
        ...

    @property
    def model(self) -> Optional[ModelConfig]:
        ...

    @property
    def instructions(self) -> Optional[str]:
        ...
class LLMCallContext(Protocol):
    """Structural type matched by both ``OptimizationContext`` and ``OptimizationJudgeContext``.

    Pair it with ``LLMCallConfig`` to write one handler that works for both
    ``handle_agent_call`` and ``handle_judge_call``. Only read access to
    ``user_input`` and ``current_variables`` is required.
    """

    @property
    def user_input(self) -> Optional[str]:
        ...

    @property
    def current_variables(self) -> Dict[str, Any]:
        ...
@dataclass
class AIJudgeCallConfig:
    """
    Configuration handed to ``handle_judge_call``.

    It holds what is needed to run a judge under either paradigm:

    * **Completions path** — send ``messages`` straight to ``chat.completions.create``.
      The complete system + user turn sequence is already assembled and interpolated.
    * **Agents path** — take ``instructions`` as the system prompt and
      ``OptimizationJudgeContext.user_input`` as the ``Runner.run`` input.

    Both ``instructions`` and ``messages`` are always populated, whether the judge
    originates from a LaunchDarkly flag (config judge) or from an inline
    acceptance statement.
    """

    key: str
    model: ModelConfig
    instructions: str
    messages: List[LDMessage]
@dataclass
class Message:
    """One conversation message: a role plus its text content."""

    role: Literal["system", "user", "assistant"]
    content: str

    def to_dict(self) -> Dict[str, str]:
        """Return the message as a ``{"role": ..., "content": ...}`` dictionary."""
        return dict(role=self.role, content=self.content)
@dataclass
class OptimizationJudge:
    """Definition of one judge used to score responses during optimization.

    NOTE(review): the field notes below are inferred from names and from how
    judges are described elsewhere in this module; confirm against the optimizer.
    """

    # Score threshold for this judge — presumably the pass/fail cut-off; TODO confirm.
    threshold: float
    # Presumably the key of a LaunchDarkly-configured judge (config judge).
    judge_key: Optional[str] = None
    # Presumably an inline acceptance statement used instead of a config judge.
    acceptance_statement: Optional[str] = None
    # Presumably flips how the score is compared with ``threshold``; verify.
    is_inverted: bool = False
@dataclass
class OptimizationContext:
    """Snapshot of a single optimization iteration."""

    # Scores and rationales from the judges, if configured.
    scores: Dict[str, JudgeResult]
    completion_response: str
    current_instructions: str
    current_parameters: Dict[str, Any]
    # Variable set chosen for this iteration; interpolated into instructions at call time.
    current_variables: Dict[str, Any]
    # Model currently being used.
    current_model: Optional[str] = None
    # User input message for this iteration.
    user_input: Optional[str] = None
    # Previous context items.
    history: Sequence[OptimizationContext] = field(default_factory=list)
    # Current iteration number.
    iteration: int = 0
    # Wall-clock time for the agent call in milliseconds.
    duration_ms: Optional[float] = None
    # Token usage reported by the agent for this iteration.
    usage: Optional[TokenUsage] = None
    # Estimated cost; USD when pricing is available, else total tokens.
    estimated_cost_usd: Optional[float] = None
    # Single running total across ALL calls in this run (generation + judges + variation).
    accumulated_token_usage: Optional[int] = None

    def copy_without_history(self) -> OptimizationContext:
        """
        Return a flattened copy: every field is carried over except ``history``.

        The copy is shallow — field values are the same objects as on ``self``.

        :return: A new OptimizationContext with the same data but empty history
        """
        carried_over = (
            "scores",
            "completion_response",
            "current_instructions",
            "current_parameters",
            "current_variables",
            "current_model",
            "user_input",
            "iteration",
            "duration_ms",
            "usage",
            "estimated_cost_usd",
            "accumulated_token_usage",
        )
        values = {name: getattr(self, name) for name in carried_over}
        # An empty tuple as history keeps the copy flat.
        return OptimizationContext(history=(), **values)

    def to_json(self) -> Dict[str, Any]:
        """
        Build a plain-dictionary view of this context suitable for ``json.dumps()``.

        Each judge result and each history entry is serialized through its own
        ``to_json()``. The ``usage`` key is included only when usage is set.

        :return: Dictionary representation of the context
        """
        serialized_scores = {
            name: verdict.to_json() for name, verdict in self.scores.items()
        }
        serialized_history = [earlier.to_json() for earlier in self.history]

        payload: Dict[str, Any] = dict(
            scores=serialized_scores,
            completion_response=self.completion_response,
            current_instructions=self.current_instructions,
            current_parameters=self.current_parameters,
            current_model=self.current_model,
            user_input=self.user_input,
            current_variables=self.current_variables,
            history=serialized_history,
            iteration=self.iteration,
            duration_ms=self.duration_ms,
            estimated_cost_usd=self.estimated_cost_usd,
            accumulated_token_usage=self.accumulated_token_usage,
        )

        token_usage = self.usage
        if token_usage is None:
            return payload

        payload["usage"] = dict(
            total=token_usage.total,
            input=token_usage.input,
            output=token_usage.output,
        )
        return payload
@dataclass
class OptimizationJudgeContext:
    """Context supplied for one judge evaluation turn."""

    # The agent response being evaluated.
    user_input: str
    # Variable set used during agent generation; each instance gets its own dict.
    current_variables: Dict[str, Any] = field(default_factory=dict)
# Callback type aliases shared by the options dataclasses in this module, so the
# full signatures are written once. They live here because every type they
# reference (OptimizationResponse and the LLMCallConfig / LLMCallContext
# Protocols) is already defined above.
#
# Both aliases are expressed with the LLMCallConfig / LLMCallContext Protocols so
# callers can write one handler for agent and judge calls alike. Handlers typed
# with the concrete types (AIAgentConfig / AIJudgeCallConfig) keep working
# because those types structurally satisfy the Protocols.

# A handler may be a plain function returning the response directly ...
_SyncLLMHandler = Callable[
    [str, LLMCallConfig, LLMCallContext, bool],
    OptimizationResponse,
]
# ... or one that returns an awaitable resolving to the response.
_AsyncLLMHandler = Callable[
    [str, LLMCallConfig, LLMCallContext, bool],
    Awaitable[OptimizationResponse],
]

HandleAgentCall = Union[_SyncLLMHandler, _AsyncLLMHandler]
HandleJudgeCall = Union[_SyncLLMHandler, _AsyncLLMHandler]

# Status values that can be passed to ``on_status_update`` callbacks.
_StatusLiteral = Literal[
    "init",
    "generating",
    "evaluating",
    "generating variation",
    "validating",
    "turn completed",
    "success",
    "failure",
    "optimizing cost/latency",
]
@dataclass
class OptimizationOptions:
    """Settings for an agent optimization run.

    Construction raises ``ValueError`` when ``model_choices`` is empty, when
    neither ``judges`` nor ``on_turn`` is supplied, or when ``judge_model`` is
    ``None``.
    """

    # --- Configuration (required) ---
    max_attempts: int
    # Model ids the LLM can choose from; at least 1 required.
    model_choices: List[str]
    # Which model to use as judge; this should remain consistent.
    judge_model: str
    # Choices of interpolated variables, one picked at random per turn.
    # Documented as needing at least 1 entry (not checked in __post_init__).
    variable_choices: List[Dict[str, Any]]

    # --- Actual agent/completion (judge) calls ---
    # Required.
    handle_agent_call: HandleAgentCall
    # Optional; falls back to handle_agent_call when omitted (both share the same signature).
    handle_judge_call: Optional[HandleJudgeCall] = None

    # --- Criteria for pass/fail (optional) ---
    # Optional list of user input messages to randomly select from.
    user_input_options: Optional[List[str]] = None
    # Auto-judges for this model that the LLM will use.
    judges: Optional[Dict[str, OptimizationJudge]] = None
    # Supply this for manual control of pass/fail.
    on_turn: Optional[Callable[[OptimizationContext], bool]] = None

    # --- Context (optional); defaults to a single anonymous context ---
    context_choices: List[Context] = field(
        default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()]
    )

    # --- Base variation (optional) ---
    # Use this specific variation as the base; defaults to the flag's default
    # variation; requires API key + project_key.
    variation_key: Optional[str] = None

    # --- Optimization controls (optional); None disables the corresponding gate/prompt ---
    latency_optimization: Optional[bool] = None
    token_optimization: Optional[bool] = None

    # --- Auto-commit (optional) ---
    auto_commit: bool = False
    # Required when auto_commit=True or variation_key is set.
    project_key: Optional[str] = None
    # Variation key/name; auto-generated if omitted.
    output_key: Optional[str] = None
    # Override to target a non-default LD instance.
    base_url: Optional[str] = None

    # --- Result and status callbacks (optional) ---
    on_passing_result: Optional[Callable[[OptimizationContext], None]] = None
    on_failing_result: Optional[Callable[[OptimizationContext], None]] = None
    # Called to provide status updates during the optimization flow.
    on_status_update: Optional[Callable[[_StatusLiteral, OptimizationContext], None]] = None

    # Stop the run when total token usage reaches this value.
    token_limit: Optional[int] = None

    def __post_init__(self):
        """Check the required options right after construction.

        :raises ValueError: if ``model_choices`` is empty, if both ``judges`` and
            ``on_turn`` are ``None``, or if ``judge_model`` is ``None``.
        """
        if len(self.model_choices) == 0:
            raise ValueError("model_choices must have at least 1 model")

        has_pass_fail_criteria = self.judges is not None or self.on_turn is not None
        if not has_pass_fail_criteria:
            raise ValueError("Either judges or on_turn must be provided")

        if self.judge_model is None:
            raise ValueError("judge_model must be provided")
@dataclass
class GroundTruthSample:
    """One ground truth evaluation sample for optimize_from_ground_truth_options.

    A sample bundles the user input, the expected response, and the variable set
    for a single evaluation. Samples are evaluated in order, and the optimization
    passes only when every sample passes its judges within the same attempt.

    :param user_input: User message sent to the agent for this evaluation.
    :param expected_response: Ideal response the agent should produce. It is injected
        into judge context so judges can score actual vs. expected.
    :param variables: Variable set interpolated into the agent instructions for this
        sample. Defaults to an empty dict if no placeholders are used.
    """

    user_input: str
    expected_response: str
    variables: Dict[str, Any] = field(default_factory=dict)
@dataclass
class GroundTruthOptimizationOptions:
    """Options for optimize_from_ground_truth_options.

    Mirrors OptimizationOptions but replaces variable_choices / user_input_options with
    ground_truth_responses. Each GroundTruthSample bundles the user input, expected
    response, and variable set for one evaluation. All N samples must pass their judges
    in the same attempt for the optimization to succeed.

    :param ground_truth_responses: Ordered list of ground truth samples to evaluate.
        At least 1 required. All samples share the same instructions and model being optimized.
    :param max_attempts: Maximum number of variation attempts before the run is marked failed.
    :param model_choices: Model IDs the variation generator may select from. At least 1 required.
    :param judge_model: Model used for judge evaluation. Should remain consistent across attempts.
        Must not be ``None``.
    :param handle_agent_call: Callback that invokes the agent and returns its response.
    :param handle_judge_call: Callback that invokes a judge LLM and returns its response.
        Falls back to ``handle_agent_call`` when omitted (both share the same signature).
    :param judges: Auto-judges (config judges and/or acceptance statements) to score each response.
    :param on_turn: Optional manual pass/fail callback applied per sample; skips judge scoring when provided.
    :param on_sample_result: Called with each sample's OptimizationContext as results arrive,
        before the overall pass/fail decision is made for the attempt.
    :param on_passing_result: Called once with the last context when all N samples pass.
    :param on_failing_result: Called once with the last context when max attempts are exhausted.
    :param on_status_update: Called on each status transition during the run.
    :param context_choices: One or more LD evaluation contexts to use. Defaults to a single
        anonymous context.
    :param variation_key: Use this specific variation as the base; defaults to the flag's
        default variation. Requires API key + project_key.
    :param latency_optimization: Optimization control; when ``None`` the corresponding
        gate/prompt is disabled.
    :param token_optimization: Optimization control; when ``None`` the corresponding
        gate/prompt is disabled.
    :param auto_commit: Auto-commit switch; off by default.
    :param project_key: Required when ``auto_commit=True`` or ``variation_key`` is set.
    :param output_key: Variation key/name; auto-generated if omitted.
    :param base_url: Override to target a non-default LD instance.
    :param token_limit: Stop the run when total token usage reaches this value.
    :raises ValueError: if ``model_choices`` or ``ground_truth_responses`` is empty, if
        neither ``judges`` nor ``on_turn`` is provided, or if ``judge_model`` is ``None``.
    """

    ground_truth_responses: List[GroundTruthSample]
    max_attempts: int
    model_choices: List[str]
    judge_model: str
    handle_agent_call: HandleAgentCall
    # Optional; falls back to handle_agent_call when omitted (both share the same signature)
    handle_judge_call: Optional[HandleJudgeCall] = None
    judges: Optional[Dict[str, OptimizationJudge]] = None
    on_turn: Optional[Callable[[OptimizationContext], bool]] = None
    on_sample_result: Optional[Callable[[OptimizationContext], None]] = None
    on_passing_result: Optional[Callable[[OptimizationContext], None]] = None
    on_failing_result: Optional[Callable[[OptimizationContext], None]] = None
    on_status_update: Optional[Callable[[_StatusLiteral, OptimizationContext], None]] = None
    # Context - Optional; defaults to a single anonymous context
    context_choices: List[Context] = field(
        default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()]
    )
    # Base variation - Optional
    # Use this specific variation as the base; defaults to the flag's default
    # variation; requires API key + project_key.
    variation_key: Optional[str] = None
    # Optimization controls - Optional; when None the corresponding gate/prompt is disabled
    latency_optimization: Optional[bool] = None
    token_optimization: Optional[bool] = None
    # Auto-commit - Optional
    auto_commit: bool = False
    project_key: Optional[str] = None  # required when auto_commit=True or variation_key is set
    output_key: Optional[str] = None  # variation key/name; auto-generated if omitted
    base_url: Optional[str] = None  # override to target a non-default LD instance
    token_limit: Optional[int] = None  # stop the run when total token usage reaches this value

    def __post_init__(self):
        """Validate required options.

        :raises ValueError: if ``model_choices`` or ``ground_truth_responses`` is empty,
            if both ``judges`` and ``on_turn`` are ``None``, or if ``judge_model`` is ``None``.
        """
        if len(self.model_choices) < 1:
            raise ValueError("model_choices must have at least 1 model")
        if len(self.ground_truth_responses) < 1:
            raise ValueError("ground_truth_responses must have at least 1 sample")
        if self.judges is None and self.on_turn is None:
            raise ValueError("Either judges or on_turn must be provided")
        # Same guard as OptimizationOptions; kept last so the earlier checks
        # retain their existing precedence.
        if self.judge_model is None:
            raise ValueError("judge_model must be provided")
@dataclass
class OptimizationFromConfigOptions:
    """Caller-supplied options for optimize_from_config.

    Settings that come from the LaunchDarkly API (max_attempts, model_choices,
    judge_model, variable_choices, user_input_options, judges) are deliberately
    absent; they are sourced from the fetched agent optimization config instead.

    :param project_key: LaunchDarkly project key used to build API paths.
    :param handle_agent_call: Callback that invokes the agent and returns its response.
    :param handle_judge_call: Callback that invokes a judge and returns its response.
        Falls back to ``handle_agent_call`` when omitted (both share the same signature).
    :param on_turn: Optional manual pass/fail callback; when provided, judge scoring is skipped.
    :param on_sample_result: Ground truth path only. Called with each sample's
        OptimizationContext as results arrive during a ground truth run.
    :param on_passing_result: Called with the winning OptimizationContext on success.
    :param on_failing_result: Called with the final OptimizationContext on failure.
    :param on_status_update: Called on each status transition; chained after the
        automatic result-persistence POST so it always runs after the record is saved.
    :param context_choices: One or more LD evaluation contexts to use. Defaults to a
        single anonymous context.
    :param base_url: Base URL of the LaunchDarkly instance. Defaults to
        https://app.launchdarkly.com. Override to target a staging instance.
    :param auto_commit: Defaults to True for config-driven runs; set False to disable.
    :param output_key: Variation key/name; auto-generated if omitted.
    """

    # Required
    project_key: str
    handle_agent_call: HandleAgentCall

    # Optional callbacks
    handle_judge_call: Optional[HandleJudgeCall] = None
    on_turn: Optional[Callable[["OptimizationContext"], bool]] = None
    on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None
    on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None
    on_failing_result: Optional[Callable[["OptimizationContext"], None]] = None
    on_status_update: Optional[Callable[[_StatusLiteral, "OptimizationContext"], None]] = None

    # Evaluation contexts; a single anonymous context unless overridden
    context_choices: List[Context] = field(
        default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()]
    )

    base_url: Optional[str] = None

    # Auto-commit is on by default for config-driven runs
    auto_commit: bool = True
    output_key: Optional[str] = None