|
5 | 5 | import logging |
6 | 6 | import os |
7 | 7 | import random |
| 8 | +import uuid |
8 | 9 | from typing import Any, Dict, List, Literal, Optional |
9 | 10 |
|
10 | 11 | from ldai import AIAgentConfig, AIJudgeConfig, AIJudgeConfigDefault, LDAIClient |
|
16 | 17 | AutoCommitConfig, |
17 | 18 | JudgeResult, |
18 | 19 | OptimizationContext, |
| 20 | + OptimizationFromConfigOptions, |
19 | 21 | OptimizationJudge, |
20 | 22 | OptimizationJudgeContext, |
21 | 23 | OptimizationOptions, |
22 | 24 | ToolDefinition, |
23 | 25 | ) |
| 26 | +from ldai_optimization.ld_api_client import ( |
| 27 | + AgentOptimizationConfig, |
| 28 | + LDApiClient, |
| 29 | + OptimizationResultPayload, |
| 30 | +) |
24 | 31 | from ldai_optimization.prompts import ( |
25 | 32 | build_message_history_text, |
26 | 33 | build_new_variation_prompt, |
|
38 | 45 |
|
39 | 46 | logger = logging.getLogger(__name__) |
40 | 47 |
|
| 48 | +# Maps SDK status strings to the API status/activity values expected by |
| 49 | +# agent_optimization_result records. Defined at module level to avoid |
| 50 | +# allocating the dict on every on_status_update invocation. |
| 51 | +_OPTIMIZATION_STATUS_MAP: Dict[str, Dict[str, str]] = { |
| 52 | + "init": {"status": "RUNNING", "activity": "PENDING"}, |
| 53 | + "generating": {"status": "RUNNING", "activity": "GENERATING"}, |
| 54 | + "evaluating": {"status": "RUNNING", "activity": "EVALUATING"}, |
| 55 | + "generating variation": {"status": "RUNNING", "activity": "GENERATING_VARIATION"}, |
| 56 | + "turn completed": {"status": "RUNNING", "activity": "COMPLETED"}, |
| 57 | + "success": {"status": "PASSED", "activity": "COMPLETED"}, |
| 58 | + "failure": {"status": "FAILED", "activity": "COMPLETED"}, |
| 59 | +} |
| 60 | + |
41 | 61 |
|
42 | 62 | class OptimizationClient: |
43 | 63 | _options: OptimizationOptions |
@@ -883,21 +903,149 @@ async def _generate_new_variation( |
883 | 903 | ) |
884 | 904 |
|
885 | 905 | async def optimize_from_config( |
886 | | - self, agent_key: str, optimization_config_key: str |
| 906 | + self, optimization_config_key: str, options: OptimizationFromConfigOptions |
887 | 907 | ) -> Any: |
888 | | - """Optimize an agent from a configuration. |
| 908 | + """Optimize an agent using a configuration fetched from the LaunchDarkly API. |
889 | 909 |
|
890 | | - :param agent_key: Identifier of the agent to optimize. |
891 | | - :param optimization_config_key: Identifier of the optimization configuration to use. |
892 | | - :return: Optimization result. |
| 910 | + The agent key, judge configuration, model choices, and other optimization |
| 911 | + parameters are all sourced from the remote agent optimization config. The |
| 912 | + caller only needs to provide the execution callbacks and evaluation contexts. |
| 913 | +
|
| 914 | + Iteration results are automatically persisted to the LaunchDarkly API so |
| 915 | + the UI can display live run progress. |
| 916 | +
|
| 917 | + :param optimization_config_key: Key of the agent optimization config to fetch. |
| 918 | + :param options: User-provided callbacks and evaluation contexts. |
| 919 | + :return: Optimization result (OptimizationContext from the final iteration). |
893 | 920 | """ |
894 | 921 | if not self._has_api_key: |
895 | 922 | raise ValueError( |
896 | 923 | "LAUNCHDARKLY_API_KEY is not set, so optimize_from_config is not available" |
897 | 924 | ) |
898 | 925 |
|
899 | | - self._agent_key = agent_key |
900 | | - raise NotImplementedError |
| 926 | + assert self._api_key is not None |
| 927 | + api_client = LDApiClient( |
| 928 | + self._api_key, |
| 929 | + **({"base_url": options.base_url} if options.base_url else {}), |
| 930 | + ) |
| 931 | + config = api_client.get_agent_optimization(options.project_key, optimization_config_key) |
| 932 | + |
| 933 | + self._agent_key = config["aiConfigKey"] |
| 934 | + optimization_id: str = config["id"] |
| 935 | + run_id = str(uuid.uuid4()) |
| 936 | + |
| 937 | + context = random.choice(options.context_choices) |
| 938 | + # _get_agent_config calls _initialize_class_members_from_config internally; |
| 939 | + # _run_optimization calls it again to reset history before the loop starts. |
| 940 | + agent_config = await self._get_agent_config(self._agent_key, context) |
| 941 | + |
| 942 | + optimization_options = self._build_options_from_config( |
| 943 | + config, options, api_client, optimization_id, run_id |
| 944 | + ) |
| 945 | + return await self._run_optimization(agent_config, optimization_options) |
| 946 | + |
| 947 | + def _build_options_from_config( |
| 948 | + self, |
| 949 | + config: AgentOptimizationConfig, |
| 950 | + options: OptimizationFromConfigOptions, |
| 951 | + api_client: LDApiClient, |
| 952 | + optimization_id: str, |
| 953 | + run_id: str, |
| 954 | + ) -> OptimizationOptions: |
| 955 | + """Map a fetched AgentOptimization config + user options into OptimizationOptions. |
| 956 | +
|
| 957 | + Acceptance statements and judge configs from the API are merged into a single |
| 958 | + judges dict. An on_status_update closure is injected to persist each iteration |
| 959 | + result to the LaunchDarkly API; any user-supplied on_status_update is chained |
| 960 | + after the persistence call. |
| 961 | +
|
| 962 | + :param config: Validated AgentOptimizationConfig from the API. |
| 963 | + :param options: User-provided options from optimize_from_config. |
| 964 | + :param api_client: Initialised LDApiClient for result persistence. |
| 965 | + :param optimization_id: UUID id of the parent agent_optimization record. |
| 966 | + :param run_id: UUID that groups all result records for this run. |
| 967 | + :return: A fully populated OptimizationOptions ready for _run_optimization. |
| 968 | + """ |
| 969 | + judges: Dict[str, OptimizationJudge] = {} |
| 970 | + |
| 971 | + for i, stmt in enumerate(config["acceptanceStatements"]): |
| 972 | + key = f"acceptance-statement-{i}" |
| 973 | + judges[key] = OptimizationJudge( |
| 974 | + threshold=float(stmt.get("threshold", 0.95)), |
| 975 | + acceptance_statement=stmt["statement"], |
| 976 | + ) |
| 977 | + |
| 978 | + for judge in config["judges"]: |
| 979 | + judges[judge["key"]] = OptimizationJudge( |
| 980 | + threshold=float(judge.get("threshold", 0.95)), |
| 981 | + judge_key=judge["key"], |
| 982 | + ) |
| 983 | + |
| 984 | + if not judges and options.on_turn is None: |
| 985 | + raise ValueError( |
| 986 | + "The optimization config has no acceptance statements or judges, " |
| 987 | + "and no on_turn callback was provided. At least one is required." |
| 988 | + ) |
| 989 | + |
| 990 | + variable_choices: List[Dict[str, Any]] = config["variableChoices"] or [{}] |
| 991 | + user_input_options: Optional[List[str]] = config["userInputOptions"] or None |
| 992 | + |
| 993 | + project_key = options.project_key |
| 994 | + config_version: int = config["version"] |
| 995 | + |
| 996 | + def _persist_and_forward( |
| 997 | + status: Literal[ |
| 998 | + "init", |
| 999 | + "generating", |
| 1000 | + "evaluating", |
| 1001 | + "generating variation", |
| 1002 | + "turn completed", |
| 1003 | + "success", |
| 1004 | + "failure", |
| 1005 | + ], |
| 1006 | + ctx: OptimizationContext, |
| 1007 | + ) -> None: |
| 1008 | + # _safe_status_update (the caller) already wraps this entire function in |
| 1009 | + # a try/except, so errors here are caught and logged without aborting the run. |
| 1010 | + mapped = _OPTIMIZATION_STATUS_MAP.get( |
| 1011 | + status, {"status": "RUNNING", "activity": "PENDING"} |
| 1012 | + ) |
| 1013 | + snapshot = ctx.copy_without_history() |
| 1014 | + payload: OptimizationResultPayload = { |
| 1015 | + "run_id": run_id, |
| 1016 | + "config_optimization_version": config_version, |
| 1017 | + "status": mapped["status"], |
| 1018 | + "activity": mapped["activity"], |
| 1019 | + "iteration": snapshot.iteration, |
| 1020 | + "instructions": snapshot.current_instructions, |
| 1021 | + "parameters": snapshot.current_parameters, |
| 1022 | + "completion_response": snapshot.completion_response, |
| 1023 | + "scores": {k: v.to_json() for k, v in snapshot.scores.items()}, |
| 1024 | + "user_input": snapshot.user_input, |
| 1025 | + } |
| 1026 | + api_client.post_agent_optimization_result(project_key, optimization_id, payload) |
| 1027 | + |
| 1028 | + if options.on_status_update: |
| 1029 | + try: |
| 1030 | + options.on_status_update(status, ctx) |
| 1031 | + except Exception: |
| 1032 | + logger.exception("User on_status_update callback failed for status=%s", status) |
| 1033 | + |
| 1034 | + return OptimizationOptions( |
| 1035 | + context_choices=options.context_choices, |
| 1036 | + max_attempts=config["maxAttempts"], |
| 1037 | + model_choices=config["modelChoices"], |
| 1038 | + judge_model=config["judgeModel"], |
| 1039 | + variable_choices=variable_choices, |
| 1040 | + handle_agent_call=options.handle_agent_call, |
| 1041 | + handle_judge_call=options.handle_judge_call, |
| 1042 | + judges=judges or None, |
| 1043 | + user_input_options=user_input_options, |
| 1044 | + on_turn=options.on_turn, |
| 1045 | + on_passing_result=options.on_passing_result, |
| 1046 | + on_failing_result=options.on_failing_result, |
| 1047 | + on_status_update=_persist_and_forward, |
| 1048 | + ) |
901 | 1049 |
|
902 | 1050 | async def _execute_agent_turn( |
903 | 1051 | self, |
|
0 commit comments