Skip to content

Commit 5c21bd0

Browse files
authored
fix: Make judge runners non-multi-turn (#185)
1 parent cbe3802 commit 5c21bd0

12 files changed

Lines changed: 280 additions & 12 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_model_runner.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,17 @@ class LangChainModelRunner(Runner):
2525
:meth:`run`.
2626
"""
2727

28-
def __init__(self, llm: BaseChatModel, config_messages: Optional[List[LDMessage]] = None):
28+
def __init__(
29+
self,
30+
llm: BaseChatModel,
31+
config_messages: Optional[List[LDMessage]] = None,
32+
multi_turn: bool = True,
33+
):
2934
self._llm = llm
3035
self._chat_history = InMemoryChatMessageHistory(
3136
messages=cast(List[BaseMessage], convert_messages_to_langchain(config_messages or []))
3237
)
38+
self._multi_turn = multi_turn
3339

3440
def get_llm(self) -> BaseChatModel:
3541
"""
@@ -61,7 +67,7 @@ async def run(
6167
else:
6268
result = await self._run_completion(langchain_messages)
6369

64-
if result.metrics.success and result.content:
70+
if result.metrics.success and result.content and self._multi_turn:
6571
self._chat_history.add_user_message(input)
6672
self._chat_history.add_ai_message(result.content)
6773

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,13 +61,16 @@ def create_agent_graph(
6161
)
6262
return LangGraphAgentGraphRunner(graph_def, tools)
6363

64-
def create_model(self, config: AIConfigKind) -> LangChainModelRunner:
64+
def create_model(self, config: AIConfigKind, multi_turn: bool = True) -> LangChainModelRunner:
6565
"""
6666
Create a configured LangChainModelRunner for the given AI config.
6767
6868
:param config: The LaunchDarkly AI configuration
69+
:param multi_turn: When ``True`` (the default) the runner accumulates
70+
successful exchanges into its conversation history. Pass ``False`` to
71+
keep history fixed at the configured baseline across ``run()`` calls.
6972
:return: LangChainModelRunner ready to invoke the model
7073
"""
7174
llm = create_langchain_model(config)
7275
config_messages = list(getattr(config, 'messages', None) or [])
73-
return LangChainModelRunner(llm, config_messages)
76+
return LangChainModelRunner(llm, config_messages, multi_turn=multi_turn)

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -281,6 +281,41 @@ async def test_accumulates_history_across_successful_calls(self, mock_llm):
281281
assert second_call_messages[1].content == 'First response'
282282
assert second_call_messages[2].content == 'Second question'
283283

284+
@pytest.mark.asyncio
285+
async def test_multi_turn_false_does_not_accumulate_history(self, mock_llm):
286+
"""When multi_turn=False the runner must not append to history on success."""
287+
mock_llm.ainvoke = AsyncMock(side_effect=[
288+
AIMessage(content='First response'),
289+
AIMessage(content='Second response'),
290+
])
291+
provider = LangChainModelRunner(mock_llm, multi_turn=False)
292+
baseline_len = len(provider._chat_history.messages)
293+
294+
await provider.run('First question')
295+
assert len(provider._chat_history.messages) == baseline_len
296+
297+
await provider.run('Second question')
298+
assert len(provider._chat_history.messages) == baseline_len
299+
300+
second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
301+
assert len(second_call_messages) == 1
302+
assert second_call_messages[0].content == 'Second question'
303+
304+
@pytest.mark.asyncio
305+
async def test_multi_turn_default_accumulates_history(self, mock_llm):
306+
"""Default behavior (multi_turn omitted) still accumulates history (preserves PR #166)."""
307+
mock_llm.ainvoke = AsyncMock(side_effect=[
308+
AIMessage(content='First response'),
309+
AIMessage(content='Second response'),
310+
])
311+
provider = LangChainModelRunner(mock_llm)
312+
baseline_len = len(provider._chat_history.messages)
313+
314+
await provider.run('First question')
315+
await provider.run('Second question')
316+
317+
assert len(provider._chat_history.messages) == baseline_len + 4
318+
284319
@pytest.mark.asyncio
285320
async def test_does_not_accumulate_history_on_failed_call(self, mock_llm):
286321
"""Should not add to history when the call fails."""

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,13 @@ def __init__(
2929
model_name: str,
3030
parameters: Dict[str, Any],
3131
config_messages: Optional[List[LDMessage]] = None,
32+
multi_turn: bool = True,
3233
):
3334
self._client = client
3435
self._model_name = model_name
3536
self._parameters = parameters
3637
self._history: List[LDMessage] = list(config_messages or [])
38+
self._multi_turn = multi_turn
3739

3840
async def run(
3941
self,
@@ -58,7 +60,7 @@ async def run(
5860
else:
5961
result = await self._run_completion(messages)
6062

61-
if result.metrics.success and result.content:
63+
if result.metrics.success and result.content and self._multi_turn:
6264
self._history.append(user_message)
6365
self._history.append(LDMessage(role='assistant', content=result.content))
6466

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def create_agent_graph(
8484
from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner
8585
return OpenAIAgentGraphRunner(graph_def, tools)
8686

87-
def create_model(self, config: AIConfigKind) -> OpenAIModelRunner:
87+
def create_model(self, config: AIConfigKind, multi_turn: bool = True) -> OpenAIModelRunner:
8888
"""
8989
Create a configured OpenAIModelRunner for the given AI config.
9090
@@ -93,6 +93,9 @@ def create_model(self, config: AIConfigKind) -> OpenAIModelRunner:
9393
needed; all other fields are passed through from the config.
9494
9595
:param config: The LaunchDarkly AI configuration
96+
:param multi_turn: When ``True`` (the default) the runner accumulates
97+
successful exchanges into its conversation history. Pass ``False`` to
98+
keep history fixed at the configured baseline across ``run()`` calls.
9699
:return: OpenAIModelRunner ready to invoke the model
97100
"""
98101
model_name, parameters = self._extract_model_config(config)
@@ -101,7 +104,9 @@ def create_model(self, config: AIConfigKind) -> OpenAIModelRunner:
101104
if tool_defs:
102105
parameters['tools'] = normalize_tool_types(tool_defs)
103106
config_messages = list(getattr(config, 'messages', None) or [])
104-
return OpenAIModelRunner(self._client, model_name, parameters, config_messages)
107+
return OpenAIModelRunner(
108+
self._client, model_name, parameters, config_messages, multi_turn=multi_turn
109+
)
105110

106111
def get_client(self) -> AsyncOpenAI:
107112
"""

packages/ai-providers/server-ai-openai/tests/test_openai_provider.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,65 @@ def make_response(text: str):
234234
{'role': 'user', 'content': 'Second question'},
235235
]
236236

237+
@pytest.mark.asyncio
238+
async def test_multi_turn_false_does_not_accumulate_history(self, mock_client):
239+
"""When multi_turn=False the runner must not append to history on success."""
240+
def make_response(text: str):
241+
r = MagicMock()
242+
r.context_wrapper = None
243+
r.choices = [MagicMock()]
244+
r.choices[0].message = MagicMock()
245+
r.choices[0].message.content = text
246+
r.usage = None
247+
return r
248+
249+
mock_client.chat = MagicMock()
250+
mock_client.chat.completions = MagicMock()
251+
mock_client.chat.completions.create = AsyncMock(side_effect=[
252+
make_response('First response'),
253+
make_response('Second response'),
254+
])
255+
256+
provider = OpenAIModelRunner(mock_client, 'gpt-4o', {}, multi_turn=False)
257+
baseline_len = len(provider._history)
258+
259+
await provider.run('First question')
260+
assert len(provider._history) == baseline_len
261+
262+
await provider.run('Second question')
263+
assert len(provider._history) == baseline_len
264+
265+
# Each call must see only the configured baseline, never the prior turn.
266+
second_call_messages = mock_client.chat.completions.create.call_args_list[1].kwargs['messages']
267+
assert second_call_messages == [{'role': 'user', 'content': 'Second question'}]
268+
269+
@pytest.mark.asyncio
270+
async def test_multi_turn_default_accumulates_history(self, mock_client):
271+
"""Default behavior (multi_turn omitted) still accumulates history (preserves PR #166)."""
272+
def make_response(text: str):
273+
r = MagicMock()
274+
r.context_wrapper = None
275+
r.choices = [MagicMock()]
276+
r.choices[0].message = MagicMock()
277+
r.choices[0].message.content = text
278+
r.usage = None
279+
return r
280+
281+
mock_client.chat = MagicMock()
282+
mock_client.chat.completions = MagicMock()
283+
mock_client.chat.completions.create = AsyncMock(side_effect=[
284+
make_response('First response'),
285+
make_response('Second response'),
286+
])
287+
288+
provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
289+
baseline_len = len(provider._history)
290+
291+
await provider.run('First question')
292+
await provider.run('Second question')
293+
294+
assert len(provider._history) == baseline_len + 4
295+
237296
@pytest.mark.asyncio
238297
async def test_does_not_accumulate_history_on_failed_call(self, mock_client):
239298
"""Should not add to history when the call fails."""

packages/sdk/server-ai/src/ldai/client.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,9 @@ def _create_judge_instance(
339339
if not judge_config.enabled:
340340
return None
341341

342-
provider = RunnerFactory.create_model(judge_config, default_ai_provider)
342+
provider = RunnerFactory.create_model(
343+
judge_config, default_ai_provider, multi_turn=False
344+
)
343345
if not provider:
344346
return None
345347

packages/sdk/server-ai/src/ldai/judge/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,13 +132,19 @@ async def evaluate_messages(
132132
"""
133133
Evaluates an AI response from chat messages and response.
134134
135+
The conversation is rendered for the judge by joining each message as
136+
``"{role}: {content}"`` on newlines, preserving who said what so the
137+
judge can distinguish user turns from assistant turns.
138+
135139
:param messages: Array of messages representing the conversation history
136140
:param response: The runner result to be evaluated
137141
:param sampling_ratio: Sampling ratio (0-1) to determine if evaluation should be processed.
138142
When ``None`` (the default), falls back to ``self.sample_rate``.
139143
:return: The result of the judge evaluation.
140144
"""
141-
input_text = '\r\n'.join([msg.content for msg in messages]) if messages else ''
145+
input_text = (
146+
'\n'.join(f'{msg.role}: {msg.content}' for msg in messages) if messages else ''
147+
)
142148
output_text = response.content
143149

144150
return await self.evaluate(input_text, output_text, sampling_ratio)

packages/sdk/server-ai/src/ldai/providers/ai_provider.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,18 @@ class AIProvider(ABC):
1515
create_model(), create_agent(), and create_agent_graph().
1616
"""
1717

18-
def create_model(self, config: Any) -> Optional[Any]:
18+
def create_model(self, config: Any, multi_turn: bool = True) -> Optional[Any]:
1919
"""
2020
Create a configured model executor for the given AI config.
2121
2222
Default implementation warns. Provider implementations should override this method.
2323
2424
:param config: The LaunchDarkly AI configuration
25+
:param multi_turn: When ``True`` (the default) the returned runner should
26+
accumulate conversation history across successful ``run()`` calls.
27+
When ``False`` each invocation starts from the same baseline history,
28+
which is required for callers that share one runner across
29+
independent invocations (e.g. judges).
2530
:return: Configured model runner instance, or None if unsupported
2631
"""
2732
log.warning('create_model not implemented by this provider')

packages/sdk/server-ai/src/ldai/providers/runner_factory.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,17 +120,25 @@ def _get_providers_to_try(
120120
def create_model(
121121
config: AIConfigKind,
122122
default_ai_provider: Optional[str] = None,
123+
multi_turn: bool = True,
123124
) -> Optional[Runner]:
124125
"""
125126
Create a model executor for the given AI completion config.
126127
127128
:param config: LaunchDarkly AI config (completion or judge)
128129
:param default_ai_provider: Optional provider override ('openai', 'langchain', …)
130+
:param multi_turn: When ``True`` (the default) the returned runner appends
131+
each successful exchange to its history so subsequent ``run()`` calls
132+
include the prior conversation. Set ``False`` for callers that share a
133+
single runner across independent invocations (for example, judges) so
134+
each call starts from the same baseline history.
129135
:return: Configured Runner ready to invoke the model, or None
130136
"""
131137
provider_name = config.provider.name.lower() if config.provider else None
132138
providers = RunnerFactory._get_providers_to_try(default_ai_provider, provider_name)
133-
return RunnerFactory._with_fallback(providers, lambda p: p.create_model(config))
139+
return RunnerFactory._with_fallback(
140+
providers, lambda p: p.create_model(config, multi_turn=multi_turn)
141+
)
134142

135143
@staticmethod
136144
def create_agent(

0 commit comments

Comments
 (0)