Skip to content

Commit 572a2aa

Browse files
committed
fix: better handling of params and custom params for optimization
1 parent e8c6692 commit 572a2aa

2 files changed

Lines changed: 262 additions & 1 deletion

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def __init__(self, ldClient: LDAIClient) -> None:
157157
self._last_succeeded_context: Optional[OptimizationContext] = None
158158
self._last_optimization_result_id: Optional[str] = None
159159
self._initial_tool_keys: List[str] = []
160+
self._initial_model_custom: Optional[Dict[str, Any]] = None
160161
self._total_token_usage: int = 0
161162

162163
if os.environ.get("LAUNCHDARKLY_API_KEY"):
@@ -861,6 +862,11 @@ async def _get_agent_config(
861862
if isinstance(t, dict) and "key" in t
862863
]
863864

865+
raw_model = raw_variation.get("model")
866+
self._initial_model_custom = (
867+
raw_model.get("custom") if isinstance(raw_model, dict) else None
868+
)
869+
864870
agent_config = dataclasses.replace(
865871
agent_config, instructions=raw_instructions
866872
)
@@ -1231,7 +1237,32 @@ def _apply_new_variation_response(
12311237
for msg in placeholder_warnings:
12321238
logger.warning("[Iteration %d] -> %s", iteration, msg)
12331239

1234-
self._current_parameters = response_data["current_parameters"]
1240+
# Merge the LLM's returned parameters into the existing ones so that custom
1241+
# parameters (e.g. response_format, max_tokens, structured-output config)
1242+
# are preserved even when the LLM omits them from its response.
1243+
original_params = self._current_parameters.copy()
1244+
new_params = response_data["current_parameters"]
1245+
merged_params = {**original_params, **new_params}
1246+
1247+
# Tools must be returned "unchanged" per the variation prompt. Always restore
1248+
# the original tools so that (a) user-defined tools are never silently dropped
1249+
# and (b) internal framework tools (e.g. structured-output tool injected by
1250+
# the agent SDK) cannot leak in from the LLM's response.
1251+
original_tools = original_params.get("tools")
1252+
if original_tools is not None:
1253+
returned_tools = new_params.get("tools")
1254+
if returned_tools is not None and returned_tools != original_tools:
1255+
logger.warning(
1256+
"[Iteration %d] -> LLM returned a modified tools list; restoring "
1257+
"original tools to prevent tool drift or internal-tool leakage. "
1258+
"Original: %s Returned: %s",
1259+
iteration,
1260+
[t.get("name") if isinstance(t, dict) else getattr(t, "name", t) for t in original_tools],
1261+
[t.get("name") if isinstance(t, dict) else getattr(t, "name", t) for t in returned_tools],
1262+
)
1263+
merged_params["tools"] = original_tools
1264+
1265+
self._current_parameters = merged_params
12351266

12361267
# Update model — it should always be provided since it's required in the schema
12371268
model_value = (
@@ -2017,6 +2048,8 @@ def _commit_variation(
20172048
}
20182049
if self._initial_tool_keys:
20192050
payload["toolKeys"] = list(self._initial_tool_keys)
2051+
if self._initial_model_custom:
2052+
payload["model"] = {"custom": self._initial_model_custom}
20202053

20212054
last_exc: Optional[Exception] = None
20222055
for attempt in range(1, 4):

packages/optimization/tests/test_client.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,162 @@ async def test_raises_after_max_retries_exhausted(self):
948948
assert self.handle_agent_call.call_count == 3
949949

950950

951+
# ---------------------------------------------------------------------------
952+
# Parameter persistence across variation generation
953+
# ---------------------------------------------------------------------------
954+
955+
956+
class TestParameterPersistence:
957+
"""Ensure custom parameters are preserved when the LLM generates a new variation."""
958+
959+
def setup_method(self):
960+
self.client = _make_client()
961+
agent_config = _make_agent_config()
962+
self.client._agent_key = "test-agent"
963+
self.client._agent_config = agent_config
964+
self.client._initial_instructions = AGENT_INSTRUCTIONS
965+
self.client._initialize_class_members_from_config(agent_config)
966+
967+
def _set_params(self, params: Dict[str, Any]) -> None:
968+
self.client._current_parameters = params
969+
970+
def _run_variation(self, returned_params: Dict[str, Any]) -> None:
971+
"""Helper: simulate _apply_new_variation_response with a given returned params dict."""
972+
variation_ctx = OptimizationContext(
973+
scores={},
974+
completion_response="",
975+
current_instructions=AGENT_INSTRUCTIONS,
976+
current_parameters={"temperature": 0.1},
977+
current_variables={},
978+
current_model="gpt-4o",
979+
user_input=None,
980+
iteration=1,
981+
)
982+
response_data = {
983+
"current_instructions": "Improved instructions.",
984+
"current_parameters": returned_params,
985+
"model": "gpt-4o",
986+
}
987+
self.client._options = _make_options()
988+
self.client._apply_new_variation_response(response_data, variation_ctx, json.dumps(response_data), 1)
989+
990+
async def test_custom_param_preserved_when_llm_omits_it(self):
991+
"""Parameters not in LLM response should be preserved from the original config."""
992+
self.client._options = _make_options()
993+
self.client._current_parameters = {"temperature": 0.7, "max_tokens": 512, "seed": 42}
994+
self._run_variation({"temperature": 0.5})
995+
assert self.client._current_parameters["max_tokens"] == 512
996+
assert self.client._current_parameters["seed"] == 42
997+
assert self.client._current_parameters["temperature"] == 0.5
998+
999+
async def test_response_format_preserved_when_llm_omits_it(self):
1000+
"""response_format (structured output config) is preserved even if LLM returns only temperature."""
1001+
self.client._options = _make_options()
1002+
self.client._current_parameters = {
1003+
"temperature": 0.7,
1004+
"response_format": {"type": "json_schema", "json_schema": {"name": "output"}},
1005+
}
1006+
self._run_variation({"temperature": 0.5})
1007+
assert self.client._current_parameters["response_format"] == {
1008+
"type": "json_schema",
1009+
"json_schema": {"name": "output"},
1010+
}
1011+
1012+
async def test_empty_returned_params_preserves_all_original_params(self):
1013+
"""If LLM returns {}, all original parameters survive."""
1014+
self.client._options = _make_options()
1015+
self.client._current_parameters = {"temperature": 0.7, "max_tokens": 256}
1016+
self._run_variation({})
1017+
assert self.client._current_parameters["temperature"] == 0.7
1018+
assert self.client._current_parameters["max_tokens"] == 256
1019+
1020+
async def test_llm_explicit_param_override_is_applied(self):
1021+
"""If the LLM explicitly returns a parameter, the new value is used."""
1022+
self.client._options = _make_options()
1023+
self.client._current_parameters = {"temperature": 0.7, "max_tokens": 256}
1024+
self._run_variation({"temperature": 0.3, "max_tokens": 128})
1025+
assert self.client._current_parameters["temperature"] == 0.3
1026+
assert self.client._current_parameters["max_tokens"] == 128
1027+
1028+
async def test_original_tools_always_restored(self):
1029+
"""Tools from the original config are always restored regardless of LLM response."""
1030+
original_tool = {"name": "my-tool", "type": "function", "description": "desc", "parameters": {}}
1031+
self.client._options = _make_options()
1032+
self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
1033+
self._run_variation({"temperature": 0.5, "tools": []})
1034+
assert self.client._current_parameters["tools"] == [original_tool]
1035+
1036+
async def test_internal_tool_leakage_is_blocked(self):
1037+
"""If LLM returns tools including an internal framework tool, original tools are restored."""
1038+
original_tool = {"name": "user-lookup", "type": "function", "description": "Looks up users", "parameters": {}}
1039+
internal_tool = {"name": "FinalAnswer", "type": "function", "description": "internal", "parameters": {}}
1040+
self.client._options = _make_options()
1041+
self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
1042+
self._run_variation({"temperature": 0.5, "tools": [original_tool, internal_tool]})
1043+
result_tools = self.client._current_parameters["tools"]
1044+
assert result_tools == [original_tool]
1045+
assert not any(t.get("name") == "FinalAnswer" for t in result_tools)
1046+
1047+
async def test_internal_tool_leakage_logs_warning(self):
1048+
"""Tool mismatch should emit a warning."""
1049+
original_tool = {"name": "my-tool", "type": "function", "description": "d", "parameters": {}}
1050+
internal_tool = {"name": "structured_output_tool", "type": "function", "description": "internal", "parameters": {}}
1051+
self.client._options = _make_options()
1052+
self.client._current_parameters = {"temperature": 0.7, "tools": [original_tool]}
1053+
with patch("ldai_optimizer.client.logger") as mock_logger:
1054+
self._run_variation({"temperature": 0.5, "tools": [internal_tool]})
1055+
warning_calls = [c for c in mock_logger.warning.call_args_list if "tool" in str(c).lower()]
1056+
assert len(warning_calls) >= 1
1057+
1058+
async def test_no_original_tools_allows_llm_returned_tools(self):
1059+
"""When the original config had no tools, the LLM is free to return tools."""
1060+
new_tool = {"name": "new-tool", "type": "function", "description": "desc", "parameters": {}}
1061+
self.client._options = _make_options()
1062+
self.client._current_parameters = {"temperature": 0.7}
1063+
self._run_variation({"temperature": 0.5, "tools": [new_tool]})
1064+
assert self.client._current_parameters.get("tools") == [new_tool]
1065+
1066+
async def test_params_preserved_across_full_optimization_loop(self):
1067+
"""End-to-end: custom params survive through a full failed-then-succeeded optimization."""
1068+
custom_params_response = json.dumps({
1069+
"current_instructions": "Improved.",
1070+
"current_parameters": {"temperature": 0.3}, # omits max_tokens and response_format
1071+
"model": "gpt-4o",
1072+
})
1073+
agent_config_with_params = _make_agent_config(
1074+
parameters={"temperature": 0.7, "max_tokens": 512, "response_format": {"type": "json_object"}},
1075+
)
1076+
mock_ldai = _make_ldai_client(agent_config=agent_config_with_params)
1077+
mock_ldai._client.variation.return_value = {
1078+
"instructions": AGENT_INSTRUCTIONS,
1079+
}
1080+
agent_responses = [
1081+
OptimizationResponse(output="Bad answer."), # iteration 1: agent
1082+
OptimizationResponse(output=custom_params_response), # iteration 1: variation
1083+
OptimizationResponse(output="Good answer."), # iteration 2: agent
1084+
OptimizationResponse(output="Good answer."), # iteration 2: validation
1085+
]
1086+
handle_agent_call = AsyncMock(side_effect=agent_responses)
1087+
judge_responses = [
1088+
OptimizationResponse(output=JUDGE_FAIL_RESPONSE),
1089+
OptimizationResponse(output=JUDGE_PASS_RESPONSE),
1090+
OptimizationResponse(output=JUDGE_PASS_RESPONSE),
1091+
]
1092+
handle_judge_call = AsyncMock(side_effect=judge_responses)
1093+
client = _make_client(mock_ldai)
1094+
options = _make_options(
1095+
handle_agent_call=handle_agent_call,
1096+
handle_judge_call=handle_judge_call,
1097+
max_attempts=3,
1098+
)
1099+
result = await client.optimize_from_options("test-agent", options)
1100+
assert result.scores["accuracy"].score == 1.0
1101+
# After variation, max_tokens and response_format should still be present
1102+
assert client._current_parameters.get("max_tokens") == 512
1103+
assert client._current_parameters.get("response_format") == {"type": "json_object"}
1104+
assert client._current_parameters.get("temperature") == 0.3 # LLM's update applied
1105+
1106+
9511107
# ---------------------------------------------------------------------------
9521108
# Full optimization loop
9531109
# ---------------------------------------------------------------------------
@@ -4048,6 +4204,48 @@ def test_toolkeys_not_in_payload_when_no_tools(self):
40484204
payload = api_client.create_ai_config_variation.call_args[0][2]
40494205
assert "toolKeys" not in payload
40504206

4207+
# --- model.custom propagation ---
4208+
4209+
def test_model_custom_included_in_payload_when_set(self):
4210+
client = self._make_client()
4211+
client._initial_model_custom = {"myApp": {"debug": True, "region": "us-east-1"}}
4212+
api_client = _make_api_client_for_commit()
4213+
4214+
client._commit_variation(
4215+
_make_winning_context(), project_key="my-project",
4216+
ai_config_key="my-agent", output_key="k", api_client=api_client,
4217+
)
4218+
4219+
payload = api_client.create_ai_config_variation.call_args[0][2]
4220+
assert payload["model"] == {"custom": {"myApp": {"debug": True, "region": "us-east-1"}}}
4221+
4222+
def test_model_not_in_payload_when_model_custom_is_none(self):
4223+
client = self._make_client()
4224+
client._initial_model_custom = None
4225+
api_client = _make_api_client_for_commit()
4226+
4227+
client._commit_variation(
4228+
_make_winning_context(), project_key="my-project",
4229+
ai_config_key="my-agent", output_key="k", api_client=api_client,
4230+
)
4231+
4232+
payload = api_client.create_ai_config_variation.call_args[0][2]
4233+
assert "model" not in payload
4234+
4235+
def test_model_not_in_payload_when_model_custom_is_empty_dict(self):
4236+
"""An empty custom dict is falsy — treated the same as absent."""
4237+
client = self._make_client()
4238+
client._initial_model_custom = {}
4239+
api_client = _make_api_client_for_commit()
4240+
4241+
client._commit_variation(
4242+
_make_winning_context(), project_key="my-project",
4243+
ai_config_key="my-agent", output_key="k", api_client=api_client,
4244+
)
4245+
4246+
payload = api_client.create_ai_config_variation.call_args[0][2]
4247+
assert "model" not in payload
4248+
40514249

40524250
# ---------------------------------------------------------------------------
40534251
# Tool key extraction from raw variation (_get_agent_config)
@@ -4097,6 +4295,36 @@ async def test_skips_tool_entries_without_key(self):
40974295
await client._get_agent_config("test-agent", LD_CONTEXT)
40984296
assert client._initial_tool_keys == ["good-tool"]
40994297

4298+
async def test_extracts_model_custom_from_raw_variation(self):
4299+
raw = {
4300+
"instructions": AGENT_INSTRUCTIONS,
4301+
"model": {"modelName": "gpt-4o", "custom": {"myApp": {"debug": True}}},
4302+
}
4303+
client = self._make_client_with_variation(raw)
4304+
await client._get_agent_config("test-agent", LD_CONTEXT)
4305+
assert client._initial_model_custom == {"myApp": {"debug": True}}
4306+
4307+
async def test_model_custom_is_none_when_variation_has_no_model(self):
4308+
raw = {"instructions": AGENT_INSTRUCTIONS}
4309+
client = self._make_client_with_variation(raw)
4310+
await client._get_agent_config("test-agent", LD_CONTEXT)
4311+
assert client._initial_model_custom is None
4312+
4313+
async def test_model_custom_is_none_when_model_has_no_custom_key(self):
4314+
raw = {
4315+
"instructions": AGENT_INSTRUCTIONS,
4316+
"model": {"modelName": "gpt-4o", "parameters": {"temperature": 0.7}},
4317+
}
4318+
client = self._make_client_with_variation(raw)
4319+
await client._get_agent_config("test-agent", LD_CONTEXT)
4320+
assert client._initial_model_custom is None
4321+
4322+
async def test_model_custom_is_none_when_model_is_not_a_dict(self):
4323+
raw = {"instructions": AGENT_INSTRUCTIONS, "model": "gpt-4o"}
4324+
client = self._make_client_with_variation(raw)
4325+
await client._get_agent_config("test-agent", LD_CONTEXT)
4326+
assert client._initial_model_custom is None
4327+
41004328

41014329
# ---------------------------------------------------------------------------
41024330
# auto_commit in optimize_from_options

0 commit comments

Comments
 (0)