Commit eb937c7

Amnah199 and sjrl authored
Add Mem0 integration - support for Mem0 platform (#391)
* Add mem0 integration
* Add custom prompt
* Udpates
* Update tests
* Fix errors and update
* Fix licenses
* Updates
* Fixes
* Fix linting
* Fix linting
* Updates
* Update agent logic
* Fix linting
* fix linting
* Fix linting
* Update agent
* Update tests
* PR comments
* PR comments
* Use experimental agent
* PR comments
* Fix linting
* Update tests
* Fix linting
* PR comments
* PR comments
* Update the init file
* Remove config
* Retrieve memories as system messages
* Add missing init files
* Fixes
* More linting issues
* Try fixing import error
* add memory_store
* Update haystack_experimental/components/agents/agent.py
  Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
* Add types
* Update add logic
* Fix version
* Add new search method
* Remove print statements
* Fix linting error
* Update tests
* Update the dependency
* Test update
* Fix bug
* Add filter conversion
* Remove example file
* PR comments
* Fix tests
* Fix linting
* Fix tests
* Update pydocs
* Update types
* Add integration tests
* Update the tests
* Update workflow
* Update workflow
* Update memory store fixture
* Update workflow
* Update workflow
* Add permission
* PR comments
* Add license
* Add pydocs
---------
Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
1 parent 7ec5be4 commit eb937c7

13 files changed

Lines changed: 892 additions & 10 deletions

.github/workflows/tests.yml

Lines changed: 19 additions & 0 deletions
@@ -24,12 +24,20 @@ on:
       - "pyproject.toml"
       - ".github/workflows/tests.yml"

+permissions:
+  id-token: write
+  contents: read
+
 env:
   PYTHON_VERSION: "3.10"
   HATCH_VERSION: "1.14.2"
   PYTHONUNBUFFERED: "1"
   FORCE_COLOR: "1"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
+  MEM0_API_KEY: ${{ secrets.MEM0_API_KEY }}
+  GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+  AWS_REGION: "us-east-1"
 jobs:
   linting:
     runs-on: ubuntu-latest
@@ -129,5 +137,16 @@ jobs:

       - name: Install Hatch
         run: pip install hatch==${{ env.HATCH_VERSION }}
+
+      # Do not authenticate on PRs from forks and on PRs created by dependabot
+      - name: AWS authentication
+        id: aws-auth
+        if: github.event_name == 'schedule' || (github.event.pull_request.head.repo.full_name == github.repository && !startsWith(github.event.pull_request.head.ref, 'dependabot/'))
+        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        with:
+          aws-region: ${{ env.AWS_REGION }}
+          role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }}
+
       - name: Run
+        if: success() && steps.aws-auth.outcome == 'success'
         run: hatch run test:integration-retry
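The `if:` condition on the AWS authentication step can be easier to reason about when written out as ordinary code. The following is only a sketch of that boolean logic in Python; the function name and parameters are hypothetical and not part of the workflow. It assumes that GitHub's `startsWith` is case-insensitive and that a missing pull request payload compares unequal to the repository name.

```python
def should_authenticate(event_name, repository, head_repo=None, head_ref=None):
    """Model of the workflow condition (hypothetical helper, not part of the repo).

    Authenticate on scheduled runs, or on pull requests whose head branch lives
    in the same repository and was not created by dependabot.
    """
    if event_name == "schedule":
        return True
    same_repo = head_repo == repository
    # Assumption: GitHub's startsWith() ignores case, so compare lowercased.
    from_dependabot = (head_ref or "").lower().startswith("dependabot/")
    return same_repo and not from_dependabot


# Example: a PR opened from a fork is not authenticated.
fork_pr = should_authenticate("pull_request", "org/repo", "someone/repo", "feature/x")
```

Because the `Run` step additionally requires `steps.aws-auth.outcome == 'success'`, any case where this sketch returns `False` also skips the integration tests in that job.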

README.md

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
 | [`Agent`][17]; [Confirmation Policies][18]; [ConfirmationUIs][19]; [ConfirmationStrategies][20]; [`ConfirmationUIResult` and `ToolExecutionDecision`][21] [HITLBreakpointException][22] | Human in the Loop | December 2025 | rich | None | [Discuss][23] |
 | [`LLMSummarizer`][24] | Document Summarizer | January 2025 | None | None | [Discuss][25] |
 | [`InMemoryChatMessageStore`][1]; [`ChatMessageRetriever`][2]; [`ChatMessageWriter`][3] | Chat Message Store, Retriever, Writer | February 2025 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/conversational_rag_using_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss][4] |
+| [`Mem0MemoryStore`][26] | MemoryStore | February 2025 | mem0ai | None | -- |

 [1]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/chat_message_stores/in_memory.py
 [2]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/chat_message_retriever.py
@@ -66,6 +67,7 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
 [23]: https://github.com/deepset-ai/haystack-experimental/discussions/381
 [24]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/sumarizers/llm_summarizer.py
 [25]: https://github.com/deepset-ai/haystack-experimental/discussions/382
+[26]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/memory_stores/mem0/memory_store.py

 ### Adopted experiments
 | Name | Type | Final release |

haystack_experimental/components/agents/agent.py

Lines changed: 90 additions & 5 deletions
@@ -57,6 +57,7 @@
 )
 from haystack_experimental.components.retrievers import ChatMessageRetriever
 from haystack_experimental.components.writers import ChatMessageWriter
+from haystack_experimental.memory_stores.types import MemoryStore

 logger = logging.getLogger(__name__)

@@ -146,6 +147,7 @@ def __init__(
         confirmation_strategies: dict[str, ConfirmationStrategy] | None = None,
         tool_invoker_kwargs: dict[str, Any] | None = None,
         chat_message_store: ChatMessageStore | None = None,
+        memory_store: MemoryStore | None = None,
     ) -> None:
         """
         Initialize the agent component.
@@ -164,6 +166,9 @@ def __init__(
         :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
             If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
+        :param chat_message_store: The ChatMessageStore that the agent can use to store
+            and retrieve chat message history.
+        :param memory_store: The memory store that the agent can use to store and retrieve memories.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         :raises ValueError: If the exit_conditions are not valid.
         """
@@ -186,6 +191,7 @@ def __init__(
         self._chat_message_writer = (
             ChatMessageWriter(chat_message_store=chat_message_store) if chat_message_store else None
         )
+        self._memory_store = memory_store

     def _initialize_fresh_execution(
         self,
@@ -198,6 +204,7 @@ def _initialize_fresh_execution(
         tools: ToolsType | list[str] | None = None,
         confirmation_strategy_context: dict[str, Any] | None = None,
         chat_message_store_kwargs: dict[str, Any] | None = None,
+        memory_store_kwargs: dict[str, Any] | None = None,
         **kwargs: dict[str, Any],
     ) -> _ExecutionContext:
         """
@@ -209,29 +216,62 @@ def _initialize_fresh_execution(
         :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
         :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
             When passing tool names, tools are selected from the Agent's originally configured tools.
+
+        :param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
+            For example, it can include the `user_id`, `run_id`, and `agent_id` parameters
+            for storing and retrieving memories.
         :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
             to confirmation strategies.
         :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
+            For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
         :param kwargs: Additional data to pass to the State used by the Agent.
         """
         system_prompt = system_prompt or self.system_prompt
-        if system_prompt is not None:
-            messages = [ChatMessage.from_system(system_prompt)] + messages
+        retrieved_memory = None
+        updated_system_prompt = system_prompt
+
+        # Retrieve memories from the memory store
+        if self._memory_store:
+            retrieved_memories = self._memory_store.search_memories(query=messages[-1].text, **memory_store_kwargs)  # type: ignore[arg-type]
+
+            # we combine the memories into a single string
+            combined_memory = "\n".join(
+                f"- MEMORY #{idx + 1}: {memory.text}" for idx, memory in enumerate(retrieved_memories)
+            )
+            retrieved_memory = ChatMessage.from_system(text=combined_memory)
+
+        if retrieved_memory:
+            memory_instruction = (
+                "\n\nWhen messages start with `[MEMORY]`, treat them as long-term "
+                "context and use them to guide the response if relevant."
+            )
+            updated_system_prompt = f"{system_prompt}{memory_instruction}"
+
+            memory_text = f"Here are the relevant memories for the user's query: {retrieved_memory.text}"
+            print(memory_text)
+            updated_memory = ChatMessage.from_system(text=memory_text)
+        else:
+            updated_memory = None
+
+        combined_messages = messages + [updated_memory] if updated_memory else messages
+        if updated_system_prompt is not None:
+            combined_messages = [ChatMessage.from_system(updated_system_prompt)] + combined_messages

         # NOTE: difference with parent method to add chat message retrieval
         if self._chat_message_retriever:
             retriever_kwargs = _select_kwargs(self._chat_message_retriever, chat_message_store_kwargs or {})
             if "chat_history_id" in retriever_kwargs:
                 messages = self._chat_message_retriever.run(
-                    current_messages=messages,
+                    current_messages=combined_messages,
                     **retriever_kwargs,
                 )["messages"]
+                combined_messages = messages

-        if all(m.is_from(ChatRole.SYSTEM) for m in messages):
+        if all(m.is_from(ChatRole.SYSTEM) for m in combined_messages):
             logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")

         state = State(schema=self.state_schema, data=kwargs)
-        state.set("messages", combined_messages)
+        state.set("messages", combined_messages)

         streaming_callback = select_streaming_callback(  # type: ignore[call-overload]
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
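The memory-injection flow in this hunk (search, join the hits into a bullet list, extend the system prompt, append a system message carrying the memories) can be tried in isolation. The sketch below is a standalone approximation, not the Agent's code: `Message` and `inject_memories` are hypothetical names, and memories are plain strings. It also deviates from the hunk in three labeled ways that may be worth considering: it skips injection when the list of memories is empty, it avoids formatting a `None` system prompt into the text, and it prefixes the memory message with `[MEMORY]` so that it matches the wording of the instruction.

```python
from dataclasses import dataclass


@dataclass
class Message:
    """Hypothetical stand-in for a chat message (role plus text)."""
    role: str
    text: str


MEMORY_INSTRUCTION = (
    "When messages start with `[MEMORY]`, treat them as long-term "
    "context and use them to guide the response if relevant."
)


def inject_memories(messages, memories, system_prompt=None):
    """Return a new message list with retrieved memories added as a system message."""
    combined = list(messages)  # copy so the caller's list is not mutated
    prompt = system_prompt
    if memories:  # deviation: do nothing when no memories were found
        bullets = "\n".join(f"- MEMORY #{i + 1}: {m}" for i, m in enumerate(memories))
        # deviation: never interpolate a None prompt into the instruction text
        prompt = f"{system_prompt}\n\n{MEMORY_INSTRUCTION}" if system_prompt else MEMORY_INSTRUCTION
        # deviation: add the [MEMORY] prefix that the instruction refers to
        combined.append(
            Message("system", f"[MEMORY] Here are the relevant memories for the user's query:\n{bullets}")
        )
    if prompt is not None:
        combined.insert(0, Message("system", prompt))
    return combined


history = [Message("user", "What should I drink this morning?")]
with_memory = inject_memories(history, ["The user prefers tea over coffee"])
```

The resulting order is the same as in the hunk: system prompt first, then the caller's messages, then the memory message.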
@@ -329,6 +369,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
         tools: ToolsType | list[str] | None = None,
         confirmation_strategy_context: dict[str, Any] | None = None,
         chat_message_store_kwargs: dict[str, Any] | None = None,
+        memory_store_kwargs: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> dict[str, Any]:
         """
@@ -352,6 +393,19 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
             can use for non-blocking user interaction.
         :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
             For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
+        :param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
+            It can include:
+            - `user_id`: The user ID to search and add memories from.
+            - `run_id`: The run ID to search and add memories from.
+            - `agent_id`: The agent ID to search and add memories from.
+            - `search_criteria`: A dictionary containing kwargs for the `search_memories` method.
+                This can include:
+                - `filters`: A dictionary of filters to search for memories.
+                - `query`: The query to search for memories.
+                    Note: If you pass this, the user query passed to the agent will be
+                    ignored for memory retrieval.
+                - `top_k`: The number of memories to return.
+                - `include_memory_metadata`: Whether to include the memory metadata in the ChatMessage.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :returns:
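To make the shape of `memory_store_kwargs` concrete, here is a toy sketch. `ToyMemoryStore` is a hypothetical in-memory stand-in written for illustration only; it is not `Mem0MemoryStore` and does not reproduce the Mem0 API. It assumes that the store receives `query` plus the unpacked kwargs (as in the `search_memories(query=..., **memory_store_kwargs)` call in the diff) and that a `query` inside `search_criteria` takes precedence over the user's query, as the docstring describes. Memories are plain strings and matching is simple whole-word overlap.

```python
from __future__ import annotations

from typing import Any


class ToyMemoryStore:
    """Hypothetical in-memory store used only to illustrate the kwargs shape."""

    def __init__(self) -> None:
        self._rows: list[tuple[str | None, str]] = []

    def add_memories(self, *, messages: list[str], user_id: str | None = None, **_: Any) -> None:
        # Store each text together with the user it belongs to.
        self._rows.extend((user_id, text) for text in messages)

    def search_memories(
        self,
        *,
        query: str | None = None,
        user_id: str | None = None,
        search_criteria: dict[str, Any] | None = None,
        **_: Any,
    ) -> list[str]:
        criteria = search_criteria or {}
        # A query inside search_criteria overrides the query derived from the user message.
        effective_query = criteria.get("query", query) or ""
        top_k = criteria.get("top_k", 5)
        wanted = set(effective_query.lower().split())
        hits = [
            text
            for owner, text in self._rows
            if owner == user_id and wanted & set(text.lower().split())
        ]
        return hits[:top_k]


store = ToyMemoryStore()
memory_store_kwargs = {"user_id": "user-123", "search_criteria": {"top_k": 5}}
store.add_memories(messages=["I prefer tea over coffee", "My dog is called Rex"], user_id="user-123")
found = store.search_memories(query="Which tea do I like", **memory_store_kwargs)
```

With a real store, the same dictionary would be passed as `memory_store_kwargs=...` to the agent's `run` call; the identifiers and values above are placeholders.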
@@ -362,6 +416,8 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
         :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
         :raises BreakpointException: If an agent breakpoint is triggered.
         """
+        memory_store_kwargs = memory_store_kwargs or {}
+
         agent_inputs = {
             "messages": messages,
             "streaming_callback": streaming_callback,
@@ -392,6 +448,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
                 tools=tools,
                 confirmation_strategy_context=confirmation_strategy_context,
                 chat_message_store_kwargs=chat_message_store_kwargs,
+                memory_store_kwargs=memory_store_kwargs,
                 **kwargs,
             )

@@ -547,6 +604,11 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
         if msgs := result.get("messages"):
             result["last_message"] = msgs[-1]

+        # Add the new conversation as memories to the memory store
+        if self._memory_store:
+            new_memories = [message for message in msgs if message.role.value != "system"]
+            self._memory_store.add_memories(messages=new_memories, **memory_store_kwargs)
+
         # Write messages to ChatMessageStore if configured
         if self._chat_message_writer:
             writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
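The write-back step above stores every non-system message from the finished run. A minimal sketch of that filter follows, using plain dictionaries instead of `ChatMessage` objects; the helper name `select_new_memories` is hypothetical. Unlike the hunk, the sketch also tolerates a missing message list, which is an assumption about how a defensive variant might behave rather than something the commit does.

```python
def select_new_memories(messages):
    """Keep only non-system messages; treat None as an empty conversation."""
    return [m for m in (messages or []) if m["role"] != "system"]


conversation = [
    {"role": "system", "text": "You are a helpful assistant."},
    {"role": "user", "text": "I moved to Berlin last month."},
    {"role": "assistant", "text": "Noted, welcome to Berlin!"},
]
to_store = select_new_memories(conversation)
```

Filtering out system messages keeps the system prompt and any injected memory messages from being written back into the store as new memories.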
@@ -567,6 +629,7 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
         tools: ToolsType | list[str] | None = None,
         confirmation_strategy_context: dict[str, Any] | None = None,
         chat_message_store_kwargs: dict[str, Any] | None = None,
+        memory_store_kwargs: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> dict[str, Any]:
         """
@@ -593,6 +656,20 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
             can use for non-blocking user interaction.
         :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
             For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
+        :param kwargs: Additional data to pass to the State schema used by the Agent.
+        :param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
+            It can include:
+            - `user_id`: The user ID to search and add memories from.
+            - `run_id`: The run ID to search and add memories from.
+            - `agent_id`: The agent ID to search and add memories from.
+            - `search_criteria`: A dictionary containing kwargs for the `search_memories` method.
+                This can include:
+                - `filters`: A dictionary of filters to search for memories.
+                - `query`: The query to search for memories.
+                    Note: If you pass this, the user query passed to the agent will be
+                    ignored for memory retrieval.
+                - `top_k`: The number of memories to return.
+                - `include_memory_metadata`: Whether to include the memory metadata in the ChatMessage.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :returns:
@@ -603,6 +680,8 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
         :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
         :raises BreakpointException: If an agent breakpoint is triggered.
         """
+        memory_store_kwargs = memory_store_kwargs or {}
+
         agent_inputs = {
             "messages": messages,
             "streaming_callback": streaming_callback,
@@ -631,6 +710,7 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
                 tools=tools,
                 confirmation_strategy_context=confirmation_strategy_context,
                 chat_message_store_kwargs=chat_message_store_kwargs,
+                memory_store_kwargs=memory_store_kwargs,
                 **kwargs,
             )

@@ -773,6 +853,11 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
         if msgs := result.get("messages"):
             result["last_message"] = msgs[-1]

+        # Add the new conversation as memories to the memory store
+        if self._memory_store:
+            new_memories = [message for message in msgs if message.role.value != "system"]
+            self._memory_store.add_memories(messages=new_memories, **memory_store_kwargs)
+
         # Write messages to ChatMessageStore if configured
         if self._chat_message_writer:
             writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from .types import MemoryStore
+
+__all__ = ["MemoryStore"]
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+from typing import TYPE_CHECKING
+
+from lazy_imports import LazyImporter
+
+_import_structure = {"memory_store": ["Mem0MemoryStore"]}
+
+if TYPE_CHECKING:
+    from .memory_store import Mem0MemoryStore as Mem0MemoryStore
+
+else:
+    sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
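The `__init__.py` above defers importing `Mem0MemoryStore` (and therefore the optional `mem0ai` dependency) until the name is first accessed. The sketch below illustrates the same idea with only the standard library, using a module-level `__getattr__` as described in PEP 562. It is not how `lazy_imports.LazyImporter` is implemented; `make_lazy_module` and the `json`/`dumps` mapping are hypothetical and chosen only so that the example runs anywhere.

```python
import importlib
import types


def make_lazy_module(name, import_structure):
    """Build a module whose listed attributes are imported on first access.

    `import_structure` maps a module name to the attribute names it provides,
    mirroring the shape of `_import_structure` in the diff.
    """
    module = types.ModuleType(name)
    owners = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

    def __getattr__(attr):
        if attr not in owners:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        # The real import happens here, on first attribute access.
        value = getattr(importlib.import_module(owners[attr]), attr)
        setattr(module, attr, value)  # cache so later lookups skip __getattr__
        return value

    # PEP 562: a function named __getattr__ in the module namespace is called
    # when normal attribute lookup fails.
    module.__getattr__ = __getattr__
    return module


lazy = make_lazy_module("lazy_demo", {"json": ["dumps"]})
loaded_before = "dumps" in vars(lazy)
encoded = lazy.dumps({"a": 1})
loaded_after = "dumps" in vars(lazy)
```

The practical benefit in the diff's setting is that importing the package does not fail when `mem0ai` is not installed; an error would only surface once the store class is actually requested.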
