Skip to content

Commit ab61928

Browse files
authored
feat: support separate turns container for short-term/long-term memory separation (#14)
Add optional cosmos_turns_container parameter to CosmosMemoryClient and AsyncCosmosMemoryClient that routes turn-type documents to a dedicated container while keeping derived memories (facts, summaries, episodic, procedural) in the main container.

Changes:
- New _container_for_type() routes writes by memory_type
- New _container_for_query() returns single container for callers that target derived memories only (get_memories, search_cosmos). Docstring clarifies it does NOT merge across containers.
- New _containers_for_query() returns list of all relevant containers for multi-container merging (no ValueError on mixed types)
- get_thread uses _containers_for_query and merges with chronological sort before recent_k slicing
- ProcessingPipeline accepts cosmos_turns_container and queries both containers in extract_memories, generate_thread_summary, and generate_user_summary
- Pipeline uses getattr guard for _turns_container for backward compat with tests constructing via __new__
- connect_cosmos() now accepts turns_container override parameter (symmetric with create_memory_store)
- create_memory_store provisions turns container when configured
- update_cosmos/delete_cosmos check both containers (main first, then turns fallback)
- close() and _drain_cosmos_client() clear _turns_container_client
- Async _init_pipeline builds a sync turns container handle for pipeline
- Added inline comment noting change-feed triggers bound to main container will NOT fire for turns in dedicated turns container
- Fully backward compatible: omit the param for single-container behavior

Note: search_cosmos, add_tags, remove_tags, get_procedural_memories query the main container only by design — these operate on derived memories, not raw turns. Raw turns are retrieved via get_thread(). The change-feed Function App trigger requires a follow-up PR to support the turns container binding.

Co-authored-by: TheovanKraay <TheovanKraay@users.noreply.github.com>
1 parent ebdcd18 commit ab61928

3 files changed

Lines changed: 374 additions & 18 deletions

File tree

agent_memory_toolkit/aio/cosmos_memory_client.py

Lines changed: 161 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
cosmos_key: Optional[str] = None,
9999
cosmos_database: Optional[str] = None,
100100
cosmos_container: Optional[str] = None,
101+
cosmos_turns_container: Optional[str] = None,
101102
cosmos_counter_container: Optional[str] = None,
102103
cosmos_lease_container: Optional[str] = None,
103104
cosmos_throughput_mode: Optional[str] = None,
@@ -131,6 +132,10 @@ def __init__(
131132
self._cosmos_key = cosmos_key
132133
self._cosmos_database = cosmos_database or "ai_memory"
133134
self._cosmos_container = cosmos_container or "memories"
135+
# None means use main container. Note: change-feed-based triggers
136+
# (e.g., Azure Functions) bound to the main container will NOT fire
137+
# for turns written to a dedicated turns container.
138+
self._cosmos_turns_container = cosmos_turns_container
134139
self._cosmos_counter_container = cosmos_counter_container or "counter"
135140
self._cosmos_lease_container = cosmos_lease_container or "leases"
136141
self._cosmos_throughput_mode = _resolve_cosmos_throughput_mode(cosmos_throughput_mode)
@@ -208,6 +213,7 @@ def __init__(
208213
# Internal Cosmos SDK handles
209214
self._cosmos_client: Any = None
210215
self._container_client: Any = None
216+
self._turns_container_client: Any = None # Separate container for turns (optional)
211217
self._counter_container_client: Any = None
212218

213219
# Composed sub-clients
@@ -274,6 +280,7 @@ async def close(self) -> None:
274280
await self._cosmos_client.close()
275281
self._cosmos_client = None
276282
self._container_client = None
283+
self._turns_container_client = None
277284
self._counter_container_client = None
278285
if self._sync_cosmos_client is not None:
279286
close = getattr(self._sync_cosmos_client, "close", None)
@@ -469,6 +476,7 @@ async def connect_cosmos(
469476
key: Optional[str] = None,
470477
database: Optional[str] = None,
471478
container: Optional[str] = None,
479+
turns_container: Optional[str] = None,
472480
) -> None:
473481
"""Establish an async connection to a Cosmos DB container.
474482
@@ -477,6 +485,9 @@ async def connect_cosmos(
477485
478486
Either *credential* (Entra ID) or *key* (account key) may be
479487
provided. *credential* takes precedence.
488+
489+
If *turns_container* is provided, it overrides the constructor's
490+
``cosmos_turns_container`` setting for this connection.
480491
"""
481492
self._cosmos_endpoint = endpoint or self._cosmos_endpoint
482493
if credential is not None:
@@ -486,6 +497,8 @@ async def connect_cosmos(
486497
self._cosmos_key = key
487498
self._cosmos_database = database or self._cosmos_database
488499
self._cosmos_container = container or self._cosmos_container
500+
if turns_container is not None:
501+
self._cosmos_turns_container = turns_container
489502

490503
_validate_connection(
491504
self._cosmos_endpoint,
@@ -505,6 +518,19 @@ async def connect_cosmos(
505518

506519
self._cosmos_client = client
507520
self._container_client = container_handle
521+
522+
# Connect turns container if configured separately
523+
if self._cosmos_turns_container:
524+
turns_handle = db.get_container_client(self._cosmos_turns_container)
525+
self._turns_container_client = turns_handle
526+
logger.info(
527+
"Async connected turns container: %s/%s",
528+
self._cosmos_database,
529+
self._cosmos_turns_container,
530+
)
531+
else:
532+
self._turns_container_client = None
533+
508534
self._init_pipeline()
509535
except Exception as exc:
510536
raise CosmosOperationError(f"Failed to connect to Cosmos DB (async): {exc}") from exc
@@ -519,6 +545,7 @@ async def create_memory_store(
519545
self,
520546
database: Optional[str] = None,
521547
container: Optional[str] = None,
548+
turns_container: Optional[str] = None,
522549
counter_container: Optional[str] = None,
523550
lease_container: Optional[str] = None,
524551
endpoint: Optional[str] = None,
@@ -560,6 +587,8 @@ async def create_memory_store(
560587
self._cosmos_key = key
561588
self._cosmos_database = database or self._cosmos_database
562589
self._cosmos_container = container or self._cosmos_container
590+
if turns_container is not None:
591+
self._cosmos_turns_container = turns_container
563592
self._cosmos_counter_container = counter_container or self._cosmos_counter_container
564593
self._cosmos_lease_container = lease_container or self._cosmos_lease_container
565594
self._cosmos_throughput_mode = _resolve_cosmos_throughput_mode(
@@ -630,6 +659,29 @@ async def create_memory_store(
630659
)
631660
self._cosmos_client = client
632661
self._container_client = container_handle
662+
663+
# Create and connect separate turns container if configured
664+
if self._cosmos_turns_container:
665+
turns_handle = await db.create_container_if_not_exists(
666+
**_build_container_kwargs(
667+
container_id=self._cosmos_turns_container,
668+
partition_key=partition_key,
669+
offer_throughput=offer_throughput,
670+
default_ttl=-1,
671+
indexing_policy=idx_policy,
672+
vector_embedding_policy=vec_policy,
673+
full_text_policy=ft_policy,
674+
)
675+
)
676+
self._turns_container_client = turns_handle
677+
logger.info(
678+
"Created turns container: %s/%s",
679+
self._cosmos_database,
680+
self._cosmos_turns_container,
681+
)
682+
else:
683+
self._turns_container_client = None
684+
633685
self._init_pipeline()
634686
except Exception as exc:
635687
raise CosmosOperationError(f"Failed to create memory store (async): {exc}") from exc
@@ -647,6 +699,68 @@ async def _require_cosmos(self) -> None:
647699
if self._container_client is None:
648700
raise CosmosNotConnectedError()
649701

702+
def _container_for_type(self, memory_type: str) -> Any:
    """Pick the container client that a document of *memory_type* is written to.

    Only ``"turn"`` documents are ever routed away from the main
    container, and only when a dedicated turns container client is set.
    Every other case resolves to the main container client.
    """
    # Non-turn types never consult the turns handle at all.
    if memory_type != "turn":
        return self._container_client

    turns_client = self._turns_container_client
    # No dedicated turns container: turns live in the main container.
    return self._container_client if turns_client is None else turns_client
707+
708+
def _container_for_query(self, memory_types: Optional[list[str]] = None) -> Any:
    """Return a single container for a read query.

    This helper is intended for callers that target a known single
    container (e.g., ``get_memories`` for derived memories, ``search_cosmos``
    for vector search). It does NOT merge across containers.

    When a dedicated turns container is configured:

    * turn-only queries → turns container
    * non-turn / unspecified / mixed queries → main memories container

    When no dedicated turns container is configured, the main container
    is always returned.

    Callers that need complete results across both containers (e.g.,
    ``get_thread``) should use ``_containers_for_query()`` instead.

    Note: ``get_memories(memory_types=None)`` returns derived memories
    only (facts, episodic, procedural, summaries). Raw turns are
    retrieved via ``get_thread()``.

    Args:
        memory_types: Memory type names the query filters on. ``None`` or
            an empty list means "unspecified".

    Returns:
        The turns container client if and only if one is configured and
        every requested type is ``"turn"``; otherwise the main container
        client.
    """
    # Unspecified types, or no dedicated turns container: main container.
    if not memory_types or self._turns_container_client is None:
        return self._container_client

    # Only a pure turn query is served from the turns container; a mixed
    # query deliberately falls back to the main container (no merging here).
    if all(t == "turn" for t in memory_types):
        return self._turns_container_client
    return self._container_client
740+
741+
def _containers_for_query(self, memory_types: Optional[list[str]] = None) -> list[Any]:
    """List every container a read query must hit to get complete results.

    Counterpart to :meth:`_container_for_query`: instead of choosing one
    container, this returns all containers relevant to *memory_types* so
    the caller can query each and merge the results.

    The main container, when included, always precedes the turns
    container in the returned list.
    """
    turns_client = self._turns_container_client
    main_client = self._container_client

    # Single-container deployment: everything lives in the main container.
    if turns_client is None:
        return [main_client]

    # No type filter: both containers may hold matching documents.
    if not memory_types:
        return [main_client, turns_client]

    wants_turns = "turn" in memory_types
    wants_derived = any(kind != "turn" for kind in memory_types)

    selected: list[Any] = []
    if wants_derived:
        selected.append(main_client)
    if wants_turns:
        selected.append(turns_client)
    return selected
763+
650764
# ------------------------------------------------------------------
651765
# Cosmos DB CRUD operations (async)
652766
# ------------------------------------------------------------------
@@ -705,7 +819,8 @@ async def add_cosmos(
705819
)
706820

707821
try:
708-
await self._container_client.upsert_item(body=body)
822+
container = self._container_for_type(memory_type)
823+
await container.upsert_item(body=body)
709824
except Exception as exc:
710825
raise CosmosOperationError(f"Async upsert failed for record {record.id}: {exc}") from exc
711826
logger.info("add_cosmos id=%s role=%s type=%s", record.id, role, memory_type)
@@ -770,7 +885,7 @@ async def push_to_cosmos(self, batch_size: int = 25) -> None:
770885
len(to_embed_text),
771886
)
772887

773-
tasks = [self._container_client.upsert_item(body=b) for b in bodies]
888+
tasks = [self._container_for_type(b.get("type", "turn")).upsert_item(body=b) for b in bodies]
774889
try:
775890
await asyncio.gather(*tasks)
776891
except Exception as exc:
@@ -870,7 +985,8 @@ async def get_memories(
870985
logger.debug("async get_memories query: %s", query)
871986

872987
try:
873-
items_iter = self._container_client.query_items(
988+
container = self._container_for_query(memory_types)
989+
items_iter = container.query_items(
874990
query=query,
875991
parameters=parameters or None,
876992
)
@@ -914,6 +1030,14 @@ async def update_cosmos(
9141030
parameters=[{"name": "@id", "value": memory_id}],
9151031
)
9161032
docs = [item async for item in items_iter]
1033+
target_container = self._container_client
1034+
if not docs and self._turns_container_client is not None:
1035+
items_iter = self._turns_container_client.query_items(
1036+
query="SELECT * FROM c WHERE c.id = @id",
1037+
parameters=[{"name": "@id", "value": memory_id}],
1038+
)
1039+
docs = [item async for item in items_iter]
1040+
target_container = self._turns_container_client
9171041
except Exception as exc:
9181042
raise CosmosOperationError(f"async update query failed: {exc}") from exc
9191043

@@ -932,7 +1056,7 @@ async def update_cosmos(
9321056
doc["updated_at"] = datetime.now(timezone.utc).isoformat()
9331057

9341058
try:
935-
await self._container_client.replace_item(item=doc["id"], body=doc)
1059+
await target_container.replace_item(item=doc["id"], body=doc)
9361060
except Exception as exc:
9371061
raise CosmosOperationError(f"async update replace failed for {memory_id}: {exc}") from exc
9381062

@@ -962,14 +1086,29 @@ async def delete_cosmos(self, memory_id: str, thread_id: str, user_id: str) -> N
9621086
],
9631087
)
9641088
docs = [item async for item in items_iter]
1089+
target_container = self._container_client
1090+
if not docs and self._turns_container_client is not None:
1091+
items_iter = self._turns_container_client.query_items(
1092+
query=(
1093+
"SELECT TOP 1 c.id FROM c WHERE c.id = @id"
1094+
" AND c.thread_id = @thread_id AND c.user_id = @user_id"
1095+
),
1096+
parameters=[
1097+
{"name": "@id", "value": memory_id},
1098+
{"name": "@thread_id", "value": thread_id},
1099+
{"name": "@user_id", "value": user_id},
1100+
],
1101+
)
1102+
docs = [item async for item in items_iter]
1103+
target_container = self._turns_container_client
9651104
except Exception as exc:
9661105
raise CosmosOperationError(f"async delete lookup failed: {exc}") from exc
9671106

9681107
if not docs:
9691108
raise MemoryNotFoundError(memory_id=memory_id, user_id=user_id, thread_id=thread_id)
9701109

9711110
try:
972-
await self._container_client.delete_item(item=memory_id, partition_key=[user_id, thread_id])
1111+
await target_container.delete_item(item=memory_id, partition_key=[user_id, thread_id])
9731112
except Exception as exc:
9741113
raise CosmosOperationError(f"async delete failed for {memory_id}: {exc}") from exc
9751114

@@ -1126,8 +1265,14 @@ async def get_thread(
11261265
logger.debug("async get_thread query: %s", query)
11271266

11281267
try:
1129-
items_iter = self._container_client.query_items(query=query, parameters=parameters)
1130-
items = [item async for item in items_iter]
1268+
containers = self._containers_for_query(memory_types)
1269+
items = []
1270+
for container in containers:
1271+
items_iter = container.query_items(query=query, parameters=parameters)
1272+
items.extend([item async for item in items_iter])
1273+
# Re-sort merged results to preserve chronological ordering
1274+
if len(containers) > 1:
1275+
items.sort(key=lambda x: x.get("created_at", ""), reverse=True)
11311276
except Exception as exc:
11321277
raise CosmosOperationError(f"async get_thread query failed: {exc}") from exc
11331278

@@ -1317,6 +1462,7 @@ async def _drain_cosmos_client(self) -> None:
13171462
logger.warning("Failed to close prior async Cosmos client during reconnect", exc_info=True)
13181463
self._cosmos_client = None
13191464
self._container_client = None
1465+
self._turns_container_client = None
13201466
self._counter_container_client = None
13211467
self._pipeline = None
13221468
if not self._processor_explicit:
@@ -1368,10 +1514,18 @@ def _init_pipeline(self) -> None:
13681514
self._sync_embeddings_client = sync_embeddings
13691515

13701516
if sync_container is not None:
1517+
# Build sync turns container if configured
1518+
sync_turns_container = None
1519+
if self._cosmos_turns_container:
1520+
try:
1521+
sync_turns_container = sync_db.get_container_client(self._cosmos_turns_container)
1522+
except Exception:
1523+
logger.warning("Failed to create sync turns container for pipeline")
13711524
self._pipeline = ProcessingPipeline(
13721525
cosmos_container=sync_container,
13731526
chat_client=self._chat_client,
13741527
embeddings_client=sync_embeddings,
1528+
cosmos_turns_container=sync_turns_container,
13751529
)
13761530
self._warn_on_embedding_dim_mismatch(sync_container)
13771531

0 commit comments

Comments
 (0)