Skip to content

Commit ea795ba

Browse files
committed
feat: support separate turns container for short-term/long-term memory separation
Add optional cosmos_turns_container parameter to CosmosMemoryClient and AsyncCosmosMemoryClient that routes turn-type documents to a dedicated container while keeping derived memories (facts, summaries, episodic, procedural) in the main container. - New _container_for_type() routes writes by memory_type - New _container_for_query() returns single container for simple cases - New _containers_for_query() returns list of containers for multi-container merging - get_thread uses _containers_for_query and merges with chronological sort - ProcessingPipeline now accepts cosmos_turns_container and queries both containers for thread-wide operations (extract_memories, generate_user_summary) - create_memory_store provisions turns container when configured - update_cosmos/delete_cosmos check both containers (main first, then turns) - close/drain methods clear _turns_container_client - Fully backward compatible: omit the param for single-container behavior
1 parent 2533cd7 commit ea795ba

3 files changed

Lines changed: 333 additions & 18 deletions

File tree

agent_memory_toolkit/aio/cosmos_memory_client.py

Lines changed: 144 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
cosmos_key: Optional[str] = None,
9999
cosmos_database: Optional[str] = None,
100100
cosmos_container: Optional[str] = None,
101+
cosmos_turns_container: Optional[str] = None,
101102
cosmos_counter_container: Optional[str] = None,
102103
cosmos_lease_container: Optional[str] = None,
103104
cosmos_throughput_mode: Optional[str] = None,
@@ -131,6 +132,7 @@ def __init__(
131132
self._cosmos_key = cosmos_key
132133
self._cosmos_database = cosmos_database or "ai_memory"
133134
self._cosmos_container = cosmos_container or "memories"
135+
self._cosmos_turns_container = cosmos_turns_container # None means use main container
134136
self._cosmos_counter_container = cosmos_counter_container or "counter"
135137
self._cosmos_lease_container = cosmos_lease_container or "leases"
136138
self._cosmos_throughput_mode = _resolve_cosmos_throughput_mode(cosmos_throughput_mode)
@@ -208,6 +210,7 @@ def __init__(
208210
# Internal Cosmos SDK handles
209211
self._cosmos_client: Any = None
210212
self._container_client: Any = None
213+
self._turns_container_client: Any = None # Separate container for turns (optional)
211214
self._counter_container_client: Any = None
212215

213216
# Composed sub-clients
@@ -274,6 +277,7 @@ async def close(self) -> None:
274277
await self._cosmos_client.close()
275278
self._cosmos_client = None
276279
self._container_client = None
280+
self._turns_container_client = None
277281
self._counter_container_client = None
278282
if self._sync_cosmos_client is not None:
279283
close = getattr(self._sync_cosmos_client, "close", None)
@@ -505,6 +509,19 @@ async def connect_cosmos(
505509

506510
self._cosmos_client = client
507511
self._container_client = container_handle
512+
513+
# Connect turns container if configured separately
514+
if self._cosmos_turns_container:
515+
turns_handle = db.get_container_client(self._cosmos_turns_container)
516+
self._turns_container_client = turns_handle
517+
logger.info(
518+
"Async connected turns container: %s/%s",
519+
self._cosmos_database,
520+
self._cosmos_turns_container,
521+
)
522+
else:
523+
self._turns_container_client = None
524+
508525
self._init_pipeline()
509526
except Exception as exc:
510527
raise CosmosOperationError(f"Failed to connect to Cosmos DB (async): {exc}") from exc
@@ -519,6 +536,7 @@ async def create_memory_store(
519536
self,
520537
database: Optional[str] = None,
521538
container: Optional[str] = None,
539+
turns_container: Optional[str] = None,
522540
counter_container: Optional[str] = None,
523541
lease_container: Optional[str] = None,
524542
endpoint: Optional[str] = None,
@@ -560,6 +578,8 @@ async def create_memory_store(
560578
self._cosmos_key = key
561579
self._cosmos_database = database or self._cosmos_database
562580
self._cosmos_container = container or self._cosmos_container
581+
if turns_container is not None:
582+
self._cosmos_turns_container = turns_container
563583
self._cosmos_counter_container = counter_container or self._cosmos_counter_container
564584
self._cosmos_lease_container = lease_container or self._cosmos_lease_container
565585
self._cosmos_throughput_mode = _resolve_cosmos_throughput_mode(
@@ -630,6 +650,29 @@ async def create_memory_store(
630650
)
631651
self._cosmos_client = client
632652
self._container_client = container_handle
653+
654+
# Create and connect separate turns container if configured
655+
if self._cosmos_turns_container:
656+
turns_handle = await db.create_container_if_not_exists(
657+
**_build_container_kwargs(
658+
container_id=self._cosmos_turns_container,
659+
partition_key=partition_key,
660+
offer_throughput=offer_throughput,
661+
default_ttl=-1,
662+
indexing_policy=idx_policy,
663+
vector_embedding_policy=vec_policy,
664+
full_text_policy=ft_policy,
665+
)
666+
)
667+
self._turns_container_client = turns_handle
668+
logger.info(
669+
"Created turns container: %s/%s",
670+
self._cosmos_database,
671+
self._cosmos_turns_container,
672+
)
673+
else:
674+
self._turns_container_client = None
675+
633676
self._init_pipeline()
634677
except Exception as exc:
635678
raise CosmosOperationError(f"Failed to create memory store (async): {exc}") from exc
@@ -647,6 +690,60 @@ async def _require_cosmos(self) -> None:
647690
if self._container_client is None:
648691
raise CosmosNotConnectedError()
649692

693+
def _container_for_type(self, memory_type: str) -> Any:
694+
"""Return the appropriate container client based on memory type."""
695+
if memory_type == "turn" and self._turns_container_client is not None:
696+
return self._turns_container_client
697+
return self._container_client
698+
699+
def _container_for_query(self, memory_types: Optional[list[str]] = None) -> Any:
700+
"""Return the appropriate container for a read query.
701+
702+
In the default single-container configuration this returns only the
703+
main memories container. When a dedicated turns container is
704+
configured:
705+
706+
* turn-only queries should search only the turns container
707+
* non-turn-only queries should search only the main memories container
708+
* mixed or unspecified queries may need to inspect both containers
709+
"""
710+
if not memory_types:
711+
return self._container_client
712+
713+
has_turn = any(t == "turn" for t in memory_types)
714+
has_not_turn = any(t != "turn" for t in memory_types)
715+
716+
if self._turns_container_client is not None:
717+
if has_turn and not has_not_turn:
718+
return self._turns_container_client
719+
if not has_turn:
720+
return self._container_client
721+
return self._container_client
722+
return self._container_client
723+
724+
def _containers_for_query(self, memory_types: Optional[list[str]] = None) -> list[Any]:
725+
"""Return candidate containers for a read query.
726+
727+
Unlike :meth:`_container_for_query` which returns a single container,
728+
this returns all containers that should be queried to get complete
729+
results, enabling callers to merge across containers.
730+
"""
731+
if self._turns_container_client is None:
732+
return [self._container_client]
733+
734+
if not memory_types:
735+
return [self._container_client, self._turns_container_client]
736+
737+
has_turn = any(t == "turn" for t in memory_types)
738+
has_not_turn = any(t != "turn" for t in memory_types)
739+
740+
if has_turn and has_not_turn:
741+
return [self._container_client, self._turns_container_client]
742+
743+
if has_turn:
744+
return [self._turns_container_client]
745+
return [self._container_client]
746+
650747
# ------------------------------------------------------------------
651748
# Cosmos DB CRUD operations (async)
652749
# ------------------------------------------------------------------
@@ -705,7 +802,8 @@ async def add_cosmos(
705802
)
706803

707804
try:
708-
await self._container_client.upsert_item(body=body)
805+
container = self._container_for_type(memory_type)
806+
await container.upsert_item(body=body)
709807
except Exception as exc:
710808
raise CosmosOperationError(f"Async upsert failed for record {record.id}: {exc}") from exc
711809
logger.info("add_cosmos id=%s role=%s type=%s", record.id, role, memory_type)
@@ -770,7 +868,7 @@ async def push_to_cosmos(self, batch_size: int = 25) -> None:
770868
len(to_embed_text),
771869
)
772870

773-
tasks = [self._container_client.upsert_item(body=b) for b in bodies]
871+
tasks = [self._container_for_type(b.get("type", "turn")).upsert_item(body=b) for b in bodies]
774872
try:
775873
await asyncio.gather(*tasks)
776874
except Exception as exc:
@@ -870,7 +968,8 @@ async def get_memories(
870968
logger.debug("async get_memories query: %s", query)
871969

872970
try:
873-
items_iter = self._container_client.query_items(
971+
container = self._container_for_query(memory_types)
972+
items_iter = container.query_items(
874973
query=query,
875974
parameters=parameters or None,
876975
)
@@ -914,6 +1013,14 @@ async def update_cosmos(
9141013
parameters=[{"name": "@id", "value": memory_id}],
9151014
)
9161015
docs = [item async for item in items_iter]
1016+
target_container = self._container_client
1017+
if not docs and self._turns_container_client is not None:
1018+
items_iter = self._turns_container_client.query_items(
1019+
query="SELECT * FROM c WHERE c.id = @id",
1020+
parameters=[{"name": "@id", "value": memory_id}],
1021+
)
1022+
docs = [item async for item in items_iter]
1023+
target_container = self._turns_container_client
9171024
except Exception as exc:
9181025
raise CosmosOperationError(f"async update query failed: {exc}") from exc
9191026

@@ -932,7 +1039,7 @@ async def update_cosmos(
9321039
doc["updated_at"] = datetime.now(timezone.utc).isoformat()
9331040

9341041
try:
935-
await self._container_client.replace_item(item=doc["id"], body=doc)
1042+
await target_container.replace_item(item=doc["id"], body=doc)
9361043
except Exception as exc:
9371044
raise CosmosOperationError(f"async update replace failed for {memory_id}: {exc}") from exc
9381045

@@ -962,14 +1069,29 @@ async def delete_cosmos(self, memory_id: str, thread_id: str, user_id: str) -> N
9621069
],
9631070
)
9641071
docs = [item async for item in items_iter]
1072+
target_container = self._container_client
1073+
if not docs and self._turns_container_client is not None:
1074+
items_iter = self._turns_container_client.query_items(
1075+
query=(
1076+
"SELECT TOP 1 c.id FROM c WHERE c.id = @id"
1077+
" AND c.thread_id = @thread_id AND c.user_id = @user_id"
1078+
),
1079+
parameters=[
1080+
{"name": "@id", "value": memory_id},
1081+
{"name": "@thread_id", "value": thread_id},
1082+
{"name": "@user_id", "value": user_id},
1083+
],
1084+
)
1085+
docs = [item async for item in items_iter]
1086+
target_container = self._turns_container_client
9651087
except Exception as exc:
9661088
raise CosmosOperationError(f"async delete lookup failed: {exc}") from exc
9671089

9681090
if not docs:
9691091
raise MemoryNotFoundError(memory_id=memory_id, user_id=user_id, thread_id=thread_id)
9701092

9711093
try:
972-
await self._container_client.delete_item(item=memory_id, partition_key=[user_id, thread_id])
1094+
await target_container.delete_item(item=memory_id, partition_key=[user_id, thread_id])
9731095
except Exception as exc:
9741096
raise CosmosOperationError(f"async delete failed for {memory_id}: {exc}") from exc
9751097

@@ -1126,8 +1248,14 @@ async def get_thread(
11261248
logger.debug("async get_thread query: %s", query)
11271249

11281250
try:
1129-
items_iter = self._container_client.query_items(query=query, parameters=parameters)
1130-
items = [item async for item in items_iter]
1251+
containers = self._containers_for_query(memory_types)
1252+
items = []
1253+
for container in containers:
1254+
items_iter = container.query_items(query=query, parameters=parameters)
1255+
items.extend([item async for item in items_iter])
1256+
# Re-sort merged results to preserve chronological ordering
1257+
if len(containers) > 1:
1258+
items.sort(key=lambda x: x.get("created_at", ""), reverse=True)
11311259
except Exception as exc:
11321260
raise CosmosOperationError(f"async get_thread query failed: {exc}") from exc
11331261

@@ -1317,6 +1445,7 @@ async def _drain_cosmos_client(self) -> None:
13171445
logger.warning("Failed to close prior async Cosmos client during reconnect", exc_info=True)
13181446
self._cosmos_client = None
13191447
self._container_client = None
1448+
self._turns_container_client = None
13201449
self._counter_container_client = None
13211450
self._pipeline = None
13221451
if not self._processor_explicit:
@@ -1368,10 +1497,18 @@ def _init_pipeline(self) -> None:
13681497
self._sync_embeddings_client = sync_embeddings
13691498

13701499
if sync_container is not None:
1500+
# Build sync turns container if configured
1501+
sync_turns_container = None
1502+
if self._cosmos_turns_container:
1503+
try:
1504+
sync_turns_container = sync_db.get_container_client(self._cosmos_turns_container)
1505+
except Exception:
1506+
logger.warning("Failed to create sync turns container for pipeline")
13711507
self._pipeline = ProcessingPipeline(
13721508
cosmos_container=sync_container,
13731509
chat_client=self._chat_client,
13741510
embeddings_client=sync_embeddings,
1511+
cosmos_turns_container=sync_turns_container,
13751512
)
13761513
self._warn_on_embedding_dim_mismatch(sync_container)
13771514

0 commit comments

Comments
 (0)