Skip to content

Commit 4bb3e78

Browse files
authored
feat: Support conversation history directly in AI Provider model runners (#166)
1 parent e6942a6 commit 4bb3e78

4 files changed

Lines changed: 140 additions & 22 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_model_runner.py

Lines changed: 20 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, cast
22

3+
from langchain_core.chat_history import InMemoryChatMessageHistory
34
from langchain_core.language_models.chat_models import BaseChatModel
4-
from langchain_core.messages import BaseMessage
5+
from langchain_core.messages import BaseMessage, HumanMessage
56
from ldai import LDMessage, log
67
from ldai.providers.runner import Runner
78
from ldai.providers.types import LDAIMetrics, RunnerResult
@@ -26,7 +27,9 @@ class LangChainModelRunner(Runner):
2627

2728
def __init__(self, llm: BaseChatModel, config_messages: Optional[List[LDMessage]] = None):
2829
self._llm = llm
29-
self._config_messages: List[LDMessage] = list(config_messages or [])
30+
self._chat_history = InMemoryChatMessageHistory(
31+
messages=cast(List[BaseMessage], convert_messages_to_langchain(config_messages or []))
32+
)
3033

3134
def get_llm(self) -> BaseChatModel:
3235
"""
@@ -44,26 +47,29 @@ async def run(
4447
"""
4548
Run the LangChain model with the given input.
4649
47-
Prepends any config messages (system prompt, instructions, etc.) stored
48-
at construction time before the user message.
49-
5050
:param input: A string prompt
5151
:param output_type: Optional JSON schema dict requesting structured output.
5252
When provided, ``parsed`` on the returned :class:`RunnerResult` is
5353
populated with the parsed JSON document.
5454
:return: :class:`RunnerResult` containing ``content``, ``metrics``,
5555
``raw`` and (when ``output_type`` is set) ``parsed``.
5656
"""
57-
messages = self._config_messages + [LDMessage(role='user', content=input)]
57+
langchain_messages = self._chat_history.messages + [HumanMessage(content=input)]
5858

5959
if output_type is not None:
60-
return await self._run_structured(messages, output_type)
61-
return await self._run_completion(messages)
60+
result = await self._run_structured(langchain_messages, output_type)
61+
else:
62+
result = await self._run_completion(langchain_messages)
63+
64+
if result.metrics.success and result.content:
65+
self._chat_history.add_user_message(input)
66+
self._chat_history.add_ai_message(result.content)
67+
68+
return result
6269

63-
async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
70+
async def _run_completion(self, messages: List[BaseMessage]) -> RunnerResult:
6471
try:
65-
langchain_messages = convert_messages_to_langchain(messages)
66-
response: BaseMessage = await self._llm.ainvoke(langchain_messages)
72+
response: BaseMessage = await self._llm.ainvoke(messages)
6773
metrics = get_ai_metrics_from_response(response)
6874

6975
content: str = ''
@@ -90,13 +96,12 @@ async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
9096

9197
async def _run_structured(
9298
self,
93-
messages: List[LDMessage],
99+
messages: List[BaseMessage],
94100
output_type: Dict[str, Any],
95101
) -> RunnerResult:
96102
try:
97-
langchain_messages = convert_messages_to_langchain(messages)
98103
structured_llm = self._llm.with_structured_output(output_type, include_raw=True)
99-
response = await structured_llm.ainvoke(langchain_messages)
104+
response = await structured_llm.ainvoke(messages)
100105

101106
if not isinstance(response, dict):
102107
log.warning(f'Structured output did not return a dict. Got: {type(response)}')

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,59 @@ async def test_returns_success_false_when_model_invocation_throws_error(self, mo
262262
assert result.metrics.success is False
263263
assert result.content == ''
264264

265+
@pytest.mark.asyncio
266+
async def test_accumulates_history_across_successful_calls(self, mock_llm):
267+
"""Should include prior exchange in messages on subsequent calls."""
268+
mock_llm.ainvoke = AsyncMock(side_effect=[
269+
AIMessage(content='First response'),
270+
AIMessage(content='Second response'),
271+
])
272+
provider = LangChainModelRunner(mock_llm)
273+
274+
await provider.run('First question')
275+
await provider.run('Second question')
276+
277+
second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
278+
roles = [type(m).__name__ for m in second_call_messages]
279+
assert roles == ['HumanMessage', 'AIMessage', 'HumanMessage']
280+
assert second_call_messages[0].content == 'First question'
281+
assert second_call_messages[1].content == 'First response'
282+
assert second_call_messages[2].content == 'Second question'
283+
284+
@pytest.mark.asyncio
285+
async def test_does_not_accumulate_history_on_failed_call(self, mock_llm):
286+
"""Should not add to history when the call fails."""
287+
mock_llm.ainvoke = AsyncMock(side_effect=Exception('Model error'))
288+
provider = LangChainModelRunner(mock_llm)
289+
290+
await provider.run('Hello')
291+
292+
mock_llm.ainvoke = AsyncMock(return_value=AIMessage(content='Recovery'))
293+
await provider.run('Try again')
294+
295+
second_call_messages = mock_llm.ainvoke.call_args_list[0][0][0]
296+
assert len(second_call_messages) == 1
297+
assert second_call_messages[0].content == 'Try again'
298+
299+
@pytest.mark.asyncio
300+
async def test_prepends_config_messages_before_history(self, mock_llm):
301+
"""Should send config messages before history on every call."""
302+
mock_llm.ainvoke = AsyncMock(side_effect=[
303+
AIMessage(content='Answer 1'),
304+
AIMessage(content='Answer 2'),
305+
])
306+
config_messages = [LDMessage(role='system', content='You are helpful.')]
307+
provider = LangChainModelRunner(mock_llm, config_messages=config_messages)
308+
309+
await provider.run('Q1')
310+
await provider.run('Q2')
311+
312+
second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
313+
assert second_call_messages[0].content == 'You are helpful.'
314+
assert second_call_messages[1].content == 'Q1'
315+
assert second_call_messages[2].content == 'Answer 1'
316+
assert second_call_messages[3].content == 'Q2'
317+
265318

266319
class TestRunStructured:
267320
"""Tests for run() with structured output."""

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_model_runner.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
self._client = client
3434
self._model_name = model_name
3535
self._parameters = parameters
36-
self._config_messages: List[LDMessage] = list(config_messages or [])
36+
self._history: List[LDMessage] = list(config_messages or [])
3737

3838
async def run(
3939
self,
@@ -43,21 +43,26 @@ async def run(
4343
"""
4444
Run the OpenAI model with the given input.
4545
46-
Prepends any config messages (system prompt, instructions, etc.) stored
47-
at construction time before the user message.
48-
4946
:param input: A string prompt
5047
:param output_type: Optional JSON schema dict requesting structured output.
5148
When provided, ``parsed`` on the returned :class:`RunnerResult` is
5249
populated with the parsed JSON document.
5350
:return: :class:`RunnerResult` containing ``content``, ``metrics``,
5451
``raw`` and (when ``output_type`` is set) ``parsed``.
5552
"""
56-
messages = self._config_messages + [LDMessage(role='user', content=input)]
53+
user_message = LDMessage(role='user', content=input)
54+
messages = self._history + [user_message]
5755

5856
if output_type is not None:
59-
return await self._run_structured(messages, output_type)
60-
return await self._run_completion(messages)
57+
result = await self._run_structured(messages, output_type)
58+
else:
59+
result = await self._run_completion(messages)
60+
61+
if result.metrics.success and result.content:
62+
self._history.append(user_message)
63+
self._history.append(LDMessage(role='assistant', content=result.content))
64+
65+
return result
6166

6267
async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
6368
try:

packages/ai-providers/server-ai-openai/tests/test_openai_provider.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,61 @@ async def test_returns_unsuccessful_response_when_exception_thrown(self, mock_cl
204204
assert result.content == ''
205205
assert result.metrics.success is False
206206

207+
@pytest.mark.asyncio
208+
async def test_accumulates_history_across_successful_calls(self, mock_client):
209+
"""Should include prior exchange in messages on subsequent calls."""
210+
def make_response(text: str):
211+
r = MagicMock()
212+
r.context_wrapper = None
213+
r.choices = [MagicMock()]
214+
r.choices[0].message = MagicMock()
215+
r.choices[0].message.content = text
216+
r.usage = None
217+
return r
218+
219+
mock_client.chat = MagicMock()
220+
mock_client.chat.completions = MagicMock()
221+
mock_client.chat.completions.create = AsyncMock(side_effect=[
222+
make_response('First response'),
223+
make_response('Second response'),
224+
])
225+
226+
provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
227+
await provider.run('First question')
228+
await provider.run('Second question')
229+
230+
second_call_messages = mock_client.chat.completions.create.call_args_list[1].kwargs['messages']
231+
assert second_call_messages == [
232+
{'role': 'user', 'content': 'First question'},
233+
{'role': 'assistant', 'content': 'First response'},
234+
{'role': 'user', 'content': 'Second question'},
235+
]
236+
237+
@pytest.mark.asyncio
238+
async def test_does_not_accumulate_history_on_failed_call(self, mock_client):
239+
"""Should not add to history when the call fails."""
240+
mock_client.chat = MagicMock()
241+
mock_client.chat.completions = MagicMock()
242+
mock_client.chat.completions.create = AsyncMock(side_effect=Exception('API Error'))
243+
244+
provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
245+
await provider.run('Hello!')
246+
247+
def make_ok_response():
248+
r = MagicMock()
249+
r.context_wrapper = None
250+
r.choices = [MagicMock()]
251+
r.choices[0].message = MagicMock()
252+
r.choices[0].message.content = 'Recovery'
253+
r.usage = None
254+
return r
255+
256+
mock_client.chat.completions.create = AsyncMock(return_value=make_ok_response())
257+
await provider.run('Try again')
258+
259+
second_call_messages = mock_client.chat.completions.create.call_args.kwargs['messages']
260+
assert second_call_messages == [{'role': 'user', 'content': 'Try again'}]
261+
207262

208263
class TestRunStructured:
209264
"""Tests for the unified run() method (structured-output path)."""

0 commit comments

Comments
 (0)