5757)
5858from haystack_experimental .components .retrievers import ChatMessageRetriever
5959from haystack_experimental .components .writers import ChatMessageWriter
60+ from haystack_experimental .memory_stores .types import MemoryStore
6061
6162logger = logging .getLogger (__name__ )
6263
@@ -146,6 +147,7 @@ def __init__(
146147 confirmation_strategies : dict [str , ConfirmationStrategy ] | None = None ,
147148 tool_invoker_kwargs : dict [str , Any ] | None = None ,
148149 chat_message_store : ChatMessageStore | None = None ,
150+ memory_store : MemoryStore | None = None ,
149151 ) -> None :
150152 """
151153 Initialize the agent component.
@@ -164,6 +166,9 @@ def __init__(
164166 :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
165167 If set to False, the exception will be turned into a chat message and passed to the LLM.
166168 :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
169+ :param chat_message_store: The ChatMessageStore that the agent can use to store
170+ and retrieve chat messages history.
171+ :param memory_store: The memory store that the agent can use to store and retrieve memories.
167172 :raises TypeError: If the chat_generator does not support tools parameter in its run method.
168173 :raises ValueError: If the exit_conditions are not valid.
169174 """
@@ -186,6 +191,7 @@ def __init__(
186191 self ._chat_message_writer = (
187192 ChatMessageWriter (chat_message_store = chat_message_store ) if chat_message_store else None
188193 )
194+ self ._memory_store = memory_store
189195
190196 def _initialize_fresh_execution (
191197 self ,
@@ -198,6 +204,7 @@ def _initialize_fresh_execution(
198204 tools : ToolsType | list [str ] | None = None ,
199205 confirmation_strategy_context : dict [str , Any ] | None = None ,
200206 chat_message_store_kwargs : dict [str , Any ] | None = None ,
207+ memory_store_kwargs : dict [str , Any ] | None = None ,
201208 ** kwargs : dict [str , Any ],
202209 ) -> _ExecutionContext :
203210 """
@@ -209,29 +216,62 @@ def _initialize_fresh_execution(
209216 :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
210217 :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
211218 When passing tool names, tools are selected from the Agent's originally configured tools.
219+
220+ :param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
221+ For example, it can include the `user_id`, `run_id`, and `agent_id` parameters
222+ for storing and retrieving memories.
212223 :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
213224 to confirmation strategies.
214225 :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
226+ For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
215227 :param kwargs: Additional data to pass to the State used by the Agent.
216228 """
217229 system_prompt = system_prompt or self .system_prompt
218- if system_prompt is not None :
219- messages = [ChatMessage .from_system (system_prompt )] + messages
230+ retrieved_memory = None
231+ updated_system_prompt = system_prompt
232+
233+ # Retrieve memories from the memory store
234+ if self ._memory_store :
235+ retrieved_memories = self ._memory_store .search_memories (query = messages [- 1 ].text , ** memory_store_kwargs ) # type: ignore[arg-type]
236+
237+ # we combine the memories into a single string
238+ combined_memory = "\n " .join (
239+ f"- MEMORY #{ idx + 1 } : { memory .text } " for idx , memory in enumerate (retrieved_memories )
240+ )
241+ retrieved_memory = ChatMessage .from_system (text = combined_memory )
242+
243+ if retrieved_memory :
244+ memory_instruction = (
245+ "\n \n When messages start with `[MEMORY]`, treat them as long-term "
246+ "context and use them to guide the response if relevant."
247+ )
248+ updated_system_prompt = f"{ system_prompt } { memory_instruction } "
249+
250+ memory_text = f"Here are the relevant memories for the user's query: { retrieved_memory .text } "
251+ print (memory_text )
252+ updated_memory = ChatMessage .from_system (text = memory_text )
253+ else :
254+ updated_memory = None
255+
256+ combined_messages = messages + [updated_memory ] if updated_memory else messages
257+ if updated_system_prompt is not None :
258+ combined_messages = [ChatMessage .from_system (updated_system_prompt )] + combined_messages
220259
221260 # NOTE: difference with parent method to add chat message retrieval
222261 if self ._chat_message_retriever :
223262 retriever_kwargs = _select_kwargs (self ._chat_message_retriever , chat_message_store_kwargs or {})
224263 if "chat_history_id" in retriever_kwargs :
225264 messages = self ._chat_message_retriever .run (
226- current_messages = messages ,
265+ current_messages = combined_messages ,
227266 ** retriever_kwargs ,
228267 )["messages" ]
268+ combined_messages = messages
229269
230- if all (m .is_from (ChatRole .SYSTEM ) for m in messages ):
270+ if all (m .is_from (ChatRole .SYSTEM ) for m in combined_messages ):
231271 logger .warning ("All messages provided to the Agent component are system messages. This is not recommended." )
232272
233273 state = State (schema = self .state_schema , data = kwargs )
234- state .set ("messages" , messages )
274+ state .set ("messages" , combined_messages )
235275
236276 streaming_callback = select_streaming_callback ( # type: ignore[call-overload]
237277 init_callback = self .streaming_callback , runtime_callback = streaming_callback , requires_async = requires_async
@@ -329,6 +369,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
329369 tools : ToolsType | list [str ] | None = None ,
330370 confirmation_strategy_context : dict [str , Any ] | None = None ,
331371 chat_message_store_kwargs : dict [str , Any ] | None = None ,
372+ memory_store_kwargs : dict [str , Any ] | None = None ,
332373 ** kwargs : Any ,
333374 ) -> dict [str , Any ]:
334375 """
@@ -352,6 +393,19 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
352393 can use for non-blocking user interaction.
353394 :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
354395 For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
396+ :param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
397+ It can include:
398+ - `user_id`: The user ID to search and add memories from.
399+ - `run_id`: The run ID to search and add memories from.
400+ - `agent_id`: The agent ID to search and add memories from.
401+ - `search_criteria`: A dictionary of containing kwargs for the `search_memories` method.
402+ This can include:
403+ - `filters`: A dictionary of filters to search for memories.
404+ - `query`: The query to search for memories.
405+ Note: If you pass this, the user query passed to the agent will be
406+ ignored for memory retrieval.
407+ - `top_k`: The number of memories to return.
408+ - `include_memory_metadata`: Whether to include the memory metadata in the ChatMessage.
355409 :param kwargs: Additional data to pass to the State schema used by the Agent.
356410 The keys must match the schema defined in the Agent's `state_schema`.
357411 :returns:
@@ -362,6 +416,8 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
362416 :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
363417 :raises BreakpointException: If an agent breakpoint is triggered.
364418 """
419+ memory_store_kwargs = memory_store_kwargs or {}
420+
365421 agent_inputs = {
366422 "messages" : messages ,
367423 "streaming_callback" : streaming_callback ,
@@ -392,6 +448,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
392448 tools = tools ,
393449 confirmation_strategy_context = confirmation_strategy_context ,
394450 chat_message_store_kwargs = chat_message_store_kwargs ,
451+ memory_store_kwargs = memory_store_kwargs ,
395452 ** kwargs ,
396453 )
397454
@@ -547,6 +604,11 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
547604 if msgs := result .get ("messages" ):
548605 result ["last_message" ] = msgs [- 1 ]
549606
607+ # Add the new conversation as memories to the memory store
608+ if self ._memory_store :
609+ new_memories = [message for message in msgs if message .role .value != "system" ]
610+ self ._memory_store .add_memories (messages = new_memories , ** memory_store_kwargs )
611+
550612 # Write messages to ChatMessageStore if configured
551613 if self ._chat_message_writer :
552614 writer_kwargs = _select_kwargs (self ._chat_message_writer , chat_message_store_kwargs or {})
@@ -567,6 +629,7 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
567629 tools : ToolsType | list [str ] | None = None ,
568630 confirmation_strategy_context : dict [str , Any ] | None = None ,
569631 chat_message_store_kwargs : dict [str , Any ] | None = None ,
632+ memory_store_kwargs : dict [str , Any ] | None = None ,
570633 ** kwargs : Any ,
571634 ) -> dict [str , Any ]:
572635 """
@@ -593,6 +656,20 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
593656 can use for non-blocking user interaction.
594657 :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
595658 For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
659+ :param kwargs: Additional data to pass to the State schema used by the Agent.
660+ :param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
661+ It can include:
662+ - `user_id`: The user ID to search and add memories from.
663+ - `run_id`: The run ID to search and add memories from.
664+ - `agent_id`: The agent ID to search and add memories from.
665+ - `search_criteria`: A dictionary of containing kwargs for the `search_memories` method.
666+ This can include:
667+ - `filters`: A dictionary of filters to search for memories.
668+ - `query`: The query to search for memories.
669+ Note: If you pass this, the user query passed to the agent will be
670+ ignored for memory retrieval.
671+ - `top_k`: The number of memories to return.
672+ - `include_memory_metadata`: Whether to include the memory metadata in the ChatMessage.
596673 :param kwargs: Additional data to pass to the State schema used by the Agent.
597674 The keys must match the schema defined in the Agent's `state_schema`.
598675 :returns:
@@ -603,6 +680,8 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
603680 :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
604681 :raises BreakpointException: If an agent breakpoint is triggered.
605682 """
683+ memory_store_kwargs = memory_store_kwargs or {}
684+
606685 agent_inputs = {
607686 "messages" : messages ,
608687 "streaming_callback" : streaming_callback ,
@@ -631,6 +710,7 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
631710 tools = tools ,
632711 confirmation_strategy_context = confirmation_strategy_context ,
633712 chat_message_store_kwargs = chat_message_store_kwargs ,
713+ memory_store_kwargs = memory_store_kwargs ,
634714 ** kwargs ,
635715 )
636716
@@ -773,6 +853,11 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
773853 if msgs := result .get ("messages" ):
774854 result ["last_message" ] = msgs [- 1 ]
775855
856+ # Add the new conversation as memories to the memory store
857+ if self ._memory_store :
858+ new_memories = [message for message in msgs if message .role .value != "system" ]
859+ self ._memory_store .add_memories (messages = new_memories , ** memory_store_kwargs )
860+
776861 # Write messages to ChatMessageStore if configured
777862 if self ._chat_message_writer :
778863 writer_kwargs = _select_kwargs (self ._chat_message_writer , chat_message_store_kwargs or {})
0 commit comments