Commit 2d6216c

fix(middleware): fix ConversationMemoryMiddleware not storing or retrieving messages
The middleware was silently failing in real agent scenarios because:

1. Response content extraction used .content on ModelResponse, but ModelResponse has .result (list[BaseMessage]), not .content. This caused assistant_content to always be empty, so nothing was stored.
2. Context messages were injected as plain dicts, but ModelRequest.messages expects LangChain message types (AnyMessage). The middleware now returns SystemMessage, HumanMessage, and AIMessage objects.
3. Assistant messages were stored with role "assistant", but redisvl's SemanticMessageHistory expects "llm" for assistant messages.

Also updated all tests to use ModelResponse(result=[AIMessage(...)]) instead of plain dicts, which previously masked these bugs.
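
For context, a minimal standalone sketch of the extraction and storage path this commit introduces (illustrative only, not the middleware's public API; extract_assistant_content is a hypothetical helper, and the import paths follow the updated tests below and may differ across langchain versions):

# Illustrative sketch only: mirrors fixes (1) and (3) from the commit message.
from langchain.agents.middleware.types import ModelResponse
from langchain_core.messages import AIMessage


def extract_assistant_content(response) -> str:
    """Pull assistant text from a handler response.

    Supports ModelResponse (whose .result is a list of messages), plain dicts,
    and any object exposing a .content attribute.
    """
    if hasattr(response, "result") and isinstance(response.result, list):
        if not response.result:
            return ""
        return getattr(response.result[-1], "content", "")
    if isinstance(response, dict):
        return response.get("content", "")
    return getattr(response, "content", "")


# The response shape the updated tests use:
response = ModelResponse(result=[AIMessage(content="I received: hello")])
assert extract_assistant_content(response) == "I received: hello"

# Per fix (3), the assistant turn is handed to redisvl's SemanticMessageHistory
# under role "llm" rather than "assistant":
stored_message = {"role": "llm", "content": extract_assistant_content(response)}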
1 parent 4d6c0c9 commit 2d6216c

4 files changed: 173 additions & 39 deletions

langgraph/middleware/redis/conversation_memory.py

Lines changed: 31 additions & 18 deletions
@@ -13,6 +13,11 @@
     ModelRequest,
     ModelResponse,
 )
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.messages import ToolMessage as LangChainToolMessage
 from langgraph.prebuilt.tool_node import ToolCallRequest
 from langgraph.types import Command
@@ -124,25 +129,27 @@ def _extract_query(self, messages: List[Union[dict[str, Any], Any]]) -> str:
 
         return ""
 
-    def _format_context_messages(
-        self, context: List[Dict[str, Any]]
-    ) -> List[Dict[str, str]]:
+    def _format_context_messages(self, context: List[Dict[str, Any]]) -> List[Any]:
         """Format retrieved context messages for injection.
 
         Args:
             context: List of retrieved context messages.
 
         Returns:
-            Formatted messages ready for injection.
+            Formatted LangChain message objects ready for injection.
         """
-        formatted = []
+        formatted: List[Any] = []
         for msg in context:
-            formatted.append(
-                {
-                    "role": msg.get("role", "user"),
-                    "content": msg.get("content", ""),
-                }
-            )
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                formatted.append(SystemMessage(content=content))
+            elif role in ("user", "human"):
+                formatted.append(HumanMessage(content=content))
+            elif role in ("llm", "ai", "assistant"):
+                formatted.append(AIMessage(content=content))
+            else:
+                formatted.append(HumanMessage(content=content))
         return formatted
 
     async def awrap_model_call(
@@ -193,10 +200,9 @@ async def awrap_model_call(
             formatted_context = self._format_context_messages(context_messages)
             # Insert context before the current messages
             # We add them as a context block
-            context_note = {
-                "role": "system",
-                "content": "Relevant context from previous conversations:",
-            }
+            context_note = SystemMessage(
+                content="Relevant context from previous conversations:"
+            )
             enhanced_messages = [context_note] + formatted_context + list(messages)
             # Support both dict-style and LangChain ModelRequest types
             if isinstance(request, dict):
@@ -211,8 +217,15 @@
         try:
             # Get the user message
            user_content = query
-            # Get the assistant response (support both dict and LangChain types)
-            if isinstance(response, dict):
+            # Get the assistant response (support ModelResponse, dict, and
+            # other LangChain types)
+            if hasattr(response, "result") and isinstance(response.result, list):
+                # ModelResponse: result is list[BaseMessage]
+                if response.result:
+                    assistant_content = getattr(response.result[-1], "content", "")
+                else:
+                    assistant_content = ""
+            elif isinstance(response, dict):
                 assistant_content = response.get("content", "")
             else:
                 assistant_content = getattr(response, "content", "")
@@ -226,7 +239,7 @@
             if assistant_content:
                 self._history.add_messages(
                     [
-                        {"role": "assistant", "content": assistant_content},
+                        {"role": "llm", "content": assistant_content},
                     ]
                 )
         except Exception as e:
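
For illustration, a standalone sketch of the role-to-message mapping that the new _format_context_messages applies before injection into ModelRequest.messages (the helper name to_langchain_messages is hypothetical; the message classes are the langchain_core ones imported above):

# Illustrative sketch only: mirrors fix (2) from the commit message.
from typing import Any, Dict, List

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage


def to_langchain_messages(context: List[Dict[str, Any]]) -> List[Any]:
    """Map stored {"role", "content"} dicts onto LangChain message objects."""
    out: List[Any] = []
    for msg in context:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if role == "system":
            out.append(SystemMessage(content=content))
        elif role in ("user", "human"):
            out.append(HumanMessage(content=content))
        elif role in ("llm", "ai", "assistant"):
            out.append(AIMessage(content=content))
        else:
            out.append(HumanMessage(content=content))  # unknown roles default to human
    return out


# Per the commit message, redisvl keeps assistant turns under role "llm",
# which this mapping turns back into AIMessage objects:
msgs = to_langchain_messages(
    [
        {"role": "user", "content": "Hello, I'm Alice."},
        {"role": "llm", "content": "Nice to meet you, Alice."},
    ]
)
assert isinstance(msgs[0], HumanMessage) and isinstance(msgs[1], AIMessage)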

tests/integration/test_middleware_end_to_end.py

Lines changed: 13 additions & 5 deletions
@@ -4,6 +4,8 @@
 """
 
 import pytest
+from langchain.agents.middleware.types import ModelResponse
+from langchain_core.messages import AIMessage
 from testcontainers.redis import RedisContainer
 
 from langgraph.middleware.redis import (
@@ -132,28 +134,34 @@ async def test_multi_turn_with_memory(self, redis_url: str) -> None:
 
         async with ConversationMemoryMiddleware(config) as middleware:
             # Simulate multi-turn conversation
-            async def mock_llm(request: dict) -> dict:
+            async def mock_llm(request: dict) -> ModelResponse:
                 messages = request.get("messages", [])
                 user_msg = ""
                 for m in reversed(messages):
                     if isinstance(m, dict) and m.get("role") == "user":
                         user_msg = m.get("content", "")
                         break
+                    elif hasattr(m, "type") and m.type == "human":
+                        user_msg = m.content
+                        break
 
-                return {"content": f"I received: {user_msg}"}
+                return ModelResponse(
+                    result=[AIMessage(content=f"I received: {user_msg}")]
+                )
 
             # Turn 1
             request1 = {"messages": [{"role": "user", "content": "Hello, I'm Alice."}]}
             response1 = await middleware.awrap_model_call(request1, mock_llm)
-            assert "Alice" in response1["content"]
+            assert "Alice" in response1.result[0].content
 
             # Turn 2
             request2 = {
                 "messages": [{"role": "user", "content": "What's my name again?"}]
             }
             response2 = await middleware.awrap_model_call(request2, mock_llm)
-            # The middleware should have injected context
-            assert "content" in response2
+            # The middleware should have injected context and returned a ModelResponse
+            assert hasattr(response2, "result")
+            assert len(response2.result) > 0
 
 
 @requires_sentence_transformers

tests/integration/test_middleware_notebook_scenarios.py

Lines changed: 56 additions & 2 deletions
@@ -7,6 +7,7 @@
 from unittest.mock import MagicMock
 
 import pytest
+from langchain.agents.middleware.types import ModelResponse
 from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
 from testcontainers.redis import RedisContainer
 
@@ -409,10 +410,12 @@ async def test_session_tag_isolation(self, redis_url: str) -> None:
         await middleware_user1._ensure_initialized_async()
         await middleware_user2._ensure_initialized_async()
 
-        async def mock_llm(request: dict) -> dict:
+        async def mock_llm(request: dict) -> ModelResponse:
             # Extract messages to check context
             msgs = request.get("messages", [])
-            return {"content": f"Response with {len(msgs)} messages"}
+            return ModelResponse(
+                result=[AIMessage(content=f"Response with {len(msgs)} messages")]
+            )
 
         # User 1 sends a message
         request1 = {
@@ -430,6 +433,57 @@ async def mock_llm(request: dict) -> dict:
         await middleware_user1.aclose()
         await middleware_user2.aclose()
 
+    @pytest.mark.asyncio
+    async def test_multi_turn_recall(self, redis_url: str) -> None:
+        """Test that middleware stores and retrieves messages across turns."""
+        import uuid
+
+        unique_name = f"memory_recall_test_{uuid.uuid4().hex[:8]}"
+        config = ConversationMemoryConfig(
+            redis_url=redis_url,
+            name=unique_name,
+            session_tag="recall_test_session",
+            top_k=5,
+            distance_threshold=0.9,
+        )
+
+        async with ConversationMemoryMiddleware(config) as middleware:
+            injected_context: list = []
+
+            async def mock_llm(request: dict) -> ModelResponse:
+                msgs = request.get("messages", [])
+                # Track injected context messages (beyond the user's own message)
+                injected_context.clear()
+                injected_context.extend(msgs)
+                return ModelResponse(result=[AIMessage(content="Got it, thanks!")])
+
+            # Turn 1: talk about Python programming
+            request1 = {
+                "messages": [
+                    HumanMessage(
+                        content="I really enjoy Python programming and data science"
+                    )
+                ]
+            }
+            await middleware.awrap_model_call(request1, mock_llm)
+
+            # Turn 2: ask about Python programming - semantically very similar
+            request2 = {
+                "messages": [
+                    HumanMessage(
+                        content="Tell me more about Python programming and data science"
+                    )
+                ]
+            }
+            await middleware.awrap_model_call(request2, mock_llm)
+
+            # The middleware should have injected context from Turn 1
+            # Total messages should be more than just the 1 user message
+            assert len(injected_context) > 1, (
+                "Expected context injection from Turn 1 but got only the "
+                f"user message. Messages: {injected_context}"
+            )
+
 
 @requires_sentence_transformers
 class TestResponseSerialization:

tests/test_middleware_conversation_memory.py

Lines changed: 73 additions & 14 deletions
@@ -3,6 +3,8 @@
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from langchain.agents.middleware.types import ModelResponse
+from langchain_core.messages import AIMessage
 
 from langgraph.middleware.redis.types import ConversationMemoryConfig
 
@@ -88,11 +90,11 @@ async def test_retrieves_relevant_context(self) -> None:
         middleware = ConversationMemoryMiddleware(config)
         await middleware._ensure_initialized_async()
 
-        async def mock_handler(request: dict) -> dict:
+        async def mock_handler(request: dict) -> ModelResponse:
             # Check that context was added (messages captured for potential assertions)
             _messages = request.get("messages", [])  # noqa: F841
             # Should have injected context
-            return {"content": "Response"}
+            return ModelResponse(result=[AIMessage(content="Response")])
 
         request = {"messages": [{"role": "user", "content": "New question"}]}
         await middleware.awrap_model_call(request, mock_handler)
@@ -120,14 +122,22 @@ async def test_stores_new_messages_after_response(self) -> None:
         middleware = ConversationMemoryMiddleware(config)
         await middleware._ensure_initialized_async()
 
-        async def mock_handler(request: dict) -> dict:
-            return {"content": "Model response"}
+        async def mock_handler(request: dict) -> ModelResponse:
+            return ModelResponse(result=[AIMessage(content="Model response")])
 
         request = {"messages": [{"role": "user", "content": "User question"}]}
         await middleware.awrap_model_call(request, mock_handler)
 
         # Should have stored both user message and assistant response
-        assert mock_history.add_messages.called
+        assert mock_history.add_messages.call_count == 2
+        # First call: user message
+        user_call = mock_history.add_messages.call_args_list[0]
+        assert user_call[0][0][0]["role"] == "user"
+        assert user_call[0][0][0]["content"] == "User question"
+        # Second call: assistant message with "llm" role for redisvl
+        llm_call = mock_history.add_messages.call_args_list[1]
+        assert llm_call[0][0][0]["role"] == "llm"
+        assert llm_call[0][0][0]["content"] == "Model response"
 
     @pytest.mark.asyncio
     async def test_uses_session_tag(self) -> None:
@@ -201,13 +211,14 @@ async def test_graceful_degradation_on_history_error(self) -> None:
         middleware = ConversationMemoryMiddleware(config)
         await middleware._ensure_initialized_async()
 
-        async def mock_handler(request: dict) -> dict:
-            return {"content": "Handler response"}
+        async def mock_handler(request: dict) -> ModelResponse:
+            return ModelResponse(result=[AIMessage(content="Handler response")])
 
         request = {"messages": [{"role": "user", "content": "Test"}]}
        result = await middleware.awrap_model_call(request, mock_handler)
 
-        assert result == {"content": "Handler response"}
+        assert hasattr(result, "result")
+        assert result.result[0].content == "Handler response"
 
     @pytest.mark.asyncio
     async def test_raises_on_history_error_without_graceful_degradation(self) -> None:
@@ -231,8 +242,8 @@ async def test_raises_on_history_error_without_graceful_degradation(self) -> None:
         middleware = ConversationMemoryMiddleware(config)
         await middleware._ensure_initialized_async()
 
-        async def mock_handler(request: dict) -> dict:
-            return {"content": "Handler response"}
+        async def mock_handler(request: dict) -> ModelResponse:
+            return ModelResponse(result=[AIMessage(content="Handler response")])
 
         request = {"messages": [{"role": "user", "content": "Test"}]}
         with pytest.raises(Exception, match="Redis error"):
@@ -294,15 +305,21 @@ async def test_context_injection_format(self) -> None:
 
         seen_messages = []
 
-        async def mock_handler(request: dict) -> dict:
+        async def mock_handler(request: dict) -> ModelResponse:
             seen_messages.extend(request.get("messages", []))
-            return {"content": "New response"}
+            return ModelResponse(result=[AIMessage(content="New response")])
 
         request = {"messages": [{"role": "user", "content": "Tell me more"}]}
         await middleware.awrap_model_call(request, mock_handler)
 
-        # Context should be injected before the current message
-        assert len(seen_messages) >= 1
+        # Context should be injected before the current message:
+        # SystemMessage (context note) + 2 context messages + 1 user message
+        assert len(seen_messages) == 4
+        # First should be the context note SystemMessage
+        from langchain_core.messages import SystemMessage
+
+        assert isinstance(seen_messages[0], SystemMessage)
+        assert "context from previous" in seen_messages[0].content.lower()
 
     @pytest.mark.asyncio
     async def test_tool_call_passes_through(self) -> None:
@@ -323,6 +340,48 @@ async def mock_handler(request: dict) -> dict:
 
         assert result == {"result": "tool result"}
 
+    @pytest.mark.asyncio
+    async def test_stores_messages_from_model_response(self) -> None:
+        """Test that both user and assistant messages are stored when handler returns ModelResponse."""
+        from langgraph.middleware.redis.conversation_memory import (
+            ConversationMemoryMiddleware,
+        )
+
+        mock_client = AsyncMock()
+        config = ConversationMemoryConfig(redis_client=mock_client)
+
+        with patch(
+            "langgraph.middleware.redis.conversation_memory.SemanticMessageHistory"
+        ) as mock_history_class:
+            mock_history = MagicMock()
+            mock_history.get_relevant = MagicMock(return_value=[])
+            mock_history.add_messages = MagicMock()
+            mock_history_class.return_value = mock_history
+
+            middleware = ConversationMemoryMiddleware(config)
+            await middleware._ensure_initialized_async()
+
+            async def mock_handler(request: dict) -> ModelResponse:
+                return ModelResponse(
+                    result=[AIMessage(content="I'm doing great, thanks!")]
+                )
+
+            request = {"messages": [{"role": "user", "content": "How are you?"}]}
+            result = await middleware.awrap_model_call(request, mock_handler)
+
+            # Verify the response is a ModelResponse with the right content
+            assert hasattr(result, "result")
+            assert result.result[0].content == "I'm doing great, thanks!"
+
+            # Verify both messages were stored
+            assert mock_history.add_messages.call_count == 2
+            user_call = mock_history.add_messages.call_args_list[0]
+            assert user_call[0][0] == [{"role": "user", "content": "How are you?"}]
+            llm_call = mock_history.add_messages.call_args_list[1]
+            assert llm_call[0][0] == [
+                {"role": "llm", "content": "I'm doing great, thanks!"}
+            ]
+
     @pytest.mark.asyncio
     async def test_handles_langchain_messages(self) -> None:
         """Test handling of LangChain-style message objects."""
