Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 161 additions & 7 deletions agent_memory_toolkit/aio/cosmos_memory_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
cosmos_key: Optional[str] = None,
cosmos_database: Optional[str] = None,
cosmos_container: Optional[str] = None,
cosmos_turns_container: Optional[str] = None,
cosmos_counter_container: Optional[str] = None,
cosmos_lease_container: Optional[str] = None,
cosmos_throughput_mode: Optional[str] = None,
Expand Down Expand Up @@ -131,6 +132,10 @@ def __init__(
self._cosmos_key = cosmos_key
self._cosmos_database = cosmos_database or "ai_memory"
self._cosmos_container = cosmos_container or "memories"
# None means use main container. Note: change-feed-based triggers
# (e.g., Azure Functions) bound to the main container will NOT fire
# for turns written to a dedicated turns container.
self._cosmos_turns_container = cosmos_turns_container
self._cosmos_counter_container = cosmos_counter_container or "counter"
self._cosmos_lease_container = cosmos_lease_container or "leases"
self._cosmos_throughput_mode = _resolve_cosmos_throughput_mode(cosmos_throughput_mode)
Expand Down Expand Up @@ -208,6 +213,7 @@ def __init__(
# Internal Cosmos SDK handles
self._cosmos_client: Any = None
self._container_client: Any = None
self._turns_container_client: Any = None # Separate container for turns (optional)
self._counter_container_client: Any = None
Comment on lines 215 to 217

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Both async \close()\ and _drain_cosmos_client()\ now set \self._turns_container_client = None\ alongside the other container handles.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Both async \close()\ and _drain_cosmos_client()\ now set \self._turns_container_client = None\ alongside the other container handles.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Both async \close()\ and _drain_cosmos_client()\ now set \self._turns_container_client = None\ alongside the other container handles.


# Composed sub-clients
Expand Down Expand Up @@ -274,6 +280,7 @@ async def close(self) -> None:
await self._cosmos_client.close()
self._cosmos_client = None
self._container_client = None
self._turns_container_client = None
self._counter_container_client = None
if self._sync_cosmos_client is not None:
close = getattr(self._sync_cosmos_client, "close", None)
Expand Down Expand Up @@ -469,6 +476,7 @@ async def connect_cosmos(
key: Optional[str] = None,
database: Optional[str] = None,
container: Optional[str] = None,
turns_container: Optional[str] = None,
) -> None:
"""Establish an async connection to a Cosmos DB container.

Expand All @@ -477,6 +485,9 @@ async def connect_cosmos(

Either *credential* (Entra ID) or *key* (account key) may be
provided. *credential* takes precedence.

If *turns_container* is provided, it overrides the constructor's
``cosmos_turns_container`` setting for this connection.
"""
self._cosmos_endpoint = endpoint or self._cosmos_endpoint
if credential is not None:
Expand All @@ -486,6 +497,8 @@ async def connect_cosmos(
self._cosmos_key = key
self._cosmos_database = database or self._cosmos_database
self._cosmos_container = container or self._cosmos_container
if turns_container is not None:
self._cosmos_turns_container = turns_container

_validate_connection(
self._cosmos_endpoint,
Expand All @@ -505,6 +518,19 @@ async def connect_cosmos(

self._cosmos_client = client
self._container_client = container_handle

# Connect turns container if configured separately
if self._cosmos_turns_container:
turns_handle = db.get_container_client(self._cosmos_turns_container)
self._turns_container_client = turns_handle
logger.info(
"Async connected turns container: %s/%s",
self._cosmos_database,
self._cosmos_turns_container,
)
else:
self._turns_container_client = None

self._init_pipeline()
except Exception as exc:
raise CosmosOperationError(f"Failed to connect to Cosmos DB (async): {exc}") from exc
Expand All @@ -519,6 +545,7 @@ async def create_memory_store(
self,
database: Optional[str] = None,
container: Optional[str] = None,
turns_container: Optional[str] = None,
counter_container: Optional[str] = None,
lease_container: Optional[str] = None,
endpoint: Optional[str] = None,
Expand Down Expand Up @@ -560,6 +587,8 @@ async def create_memory_store(
self._cosmos_key = key
self._cosmos_database = database or self._cosmos_database
self._cosmos_container = container or self._cosmos_container
if turns_container is not None:
self._cosmos_turns_container = turns_container
self._cosmos_counter_container = counter_container or self._cosmos_counter_container
self._cosmos_lease_container = lease_container or self._cosmos_lease_container
self._cosmos_throughput_mode = _resolve_cosmos_throughput_mode(
Expand Down Expand Up @@ -630,6 +659,29 @@ async def create_memory_store(
)
self._cosmos_client = client
self._container_client = container_handle

# Create and connect separate turns container if configured
if self._cosmos_turns_container:
turns_handle = await db.create_container_if_not_exists(
**_build_container_kwargs(
container_id=self._cosmos_turns_container,
partition_key=partition_key,
offer_throughput=offer_throughput,
default_ttl=-1,
indexing_policy=idx_policy,
vector_embedding_policy=vec_policy,
full_text_policy=ft_policy,
)
)
self._turns_container_client = turns_handle
logger.info(
"Created turns container: %s/%s",
self._cosmos_database,
self._cosmos_turns_container,
)
else:
self._turns_container_client = None

self._init_pipeline()
except Exception as exc:
raise CosmosOperationError(f"Failed to create memory store (async): {exc}") from exc
Expand All @@ -647,6 +699,68 @@ async def _require_cosmos(self) -> None:
if self._container_client is None:
raise CosmosNotConnectedError()

def _container_for_type(self, memory_type: str) -> Any:
"""Return the appropriate container client based on memory type."""
if memory_type == "turn" and self._turns_container_client is not None:
return self._turns_container_client
return self._container_client

def _container_for_query(self, memory_types: Optional[list[str]] = None) -> Any:
"""Return a single container for a read query.

This helper is intended for callers that target a known single
container (e.g., ``get_memories`` for derived memories, ``search_cosmos``
for vector search). It does NOT merge across containers.

When a dedicated turns container is configured:

* turn-only queries → turns container
* non-turn / unspecified / mixed queries → main memories container

Callers that need complete results across both containers (e.g.,
``get_thread``) should use ``_containers_for_query()`` instead.

Note: ``get_memories(memory_types=None)`` returns derived memories
only (facts, episodic, procedural, summaries). Raw turns are
retrieved via ``get_thread()``.
"""
if not memory_types:
return self._container_client

has_turn = any(t == "turn" for t in memory_types)
has_not_turn = any(t != "turn" for t in memory_types)

if self._turns_container_client is not None:
if has_turn and not has_not_turn:
return self._turns_container_client
if not has_turn:
return self._container_client
return self._container_client
return self._container_client

def _containers_for_query(self, memory_types: Optional[list[str]] = None) -> list[Any]:
"""Return candidate containers for a read query.

Unlike :meth:`_container_for_query` which returns a single container,
this returns all containers that should be queried to get complete
results, enabling callers to merge across containers.
"""
if self._turns_container_client is None:
return [self._container_client]

if not memory_types:
return [self._container_client, self._turns_container_client]

has_turn = any(t == "turn" for t in memory_types)
has_not_turn = any(t != "turn" for t in memory_types)

if has_turn and has_not_turn:
return [self._container_client, self._turns_container_client]

if has_turn:
return [self._turns_container_client]
return [self._container_client]

# ------------------------------------------------------------------
# Cosmos DB CRUD operations (async)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -705,7 +819,8 @@ async def add_cosmos(
)

try:
await self._container_client.upsert_item(body=body)
container = self._container_for_type(memory_type)
await container.upsert_item(body=body)
except Exception as exc:
raise CosmosOperationError(f"Async upsert failed for record {record.id}: {exc}") from exc
logger.info("add_cosmos id=%s role=%s type=%s", record.id, role, memory_type)
Expand Down Expand Up @@ -770,7 +885,7 @@ async def push_to_cosmos(self, batch_size: int = 25) -> None:
len(to_embed_text),
)

tasks = [self._container_client.upsert_item(body=b) for b in bodies]
tasks = [self._container_for_type(b.get("type", "turn")).upsert_item(body=b) for b in bodies]
try:
await asyncio.gather(*tasks)
except Exception as exc:
Expand Down Expand Up @@ -870,7 +985,8 @@ async def get_memories(
logger.debug("async get_memories query: %s", query)

try:
items_iter = self._container_client.query_items(
container = self._container_for_query(memory_types)
items_iter = container.query_items(
query=query,
parameters=parameters or None,
)
Expand Down Expand Up @@ -914,6 +1030,14 @@ async def update_cosmos(
parameters=[{"name": "@id", "value": memory_id}],
)
docs = [item async for item in items_iter]
target_container = self._container_client
if not docs and self._turns_container_client is not None:
items_iter = self._turns_container_client.query_items(
query="SELECT * FROM c WHERE c.id = @id",
parameters=[{"name": "@id", "value": memory_id}],
)
docs = [item async for item in items_iter]
target_container = self._turns_container_client
except Exception as exc:
raise CosmosOperationError(f"async update query failed: {exc}") from exc

Expand All @@ -932,7 +1056,7 @@ async def update_cosmos(
doc["updated_at"] = datetime.now(timezone.utc).isoformat()

try:
await self._container_client.replace_item(item=doc["id"], body=doc)
await target_container.replace_item(item=doc["id"], body=doc)
except Exception as exc:
raise CosmosOperationError(f"async update replace failed for {memory_id}: {exc}") from exc

Expand Down Expand Up @@ -962,14 +1086,29 @@ async def delete_cosmos(self, memory_id: str, thread_id: str, user_id: str) -> N
],
)
docs = [item async for item in items_iter]
target_container = self._container_client
if not docs and self._turns_container_client is not None:
items_iter = self._turns_container_client.query_items(
query=(
"SELECT TOP 1 c.id FROM c WHERE c.id = @id"
" AND c.thread_id = @thread_id AND c.user_id = @user_id"
),
parameters=[
{"name": "@id", "value": memory_id},
{"name": "@thread_id", "value": thread_id},
{"name": "@user_id", "value": user_id},
],
)
docs = [item async for item in items_iter]
target_container = self._turns_container_client
except Exception as exc:
raise CosmosOperationError(f"async delete lookup failed: {exc}") from exc

if not docs:
raise MemoryNotFoundError(memory_id=memory_id, user_id=user_id, thread_id=thread_id)

try:
await self._container_client.delete_item(item=memory_id, partition_key=[user_id, thread_id])
await target_container.delete_item(item=memory_id, partition_key=[user_id, thread_id])
except Exception as exc:
raise CosmosOperationError(f"async delete failed for {memory_id}: {exc}") from exc

Expand Down Expand Up @@ -1126,8 +1265,14 @@ async def get_thread(
logger.debug("async get_thread query: %s", query)

try:
items_iter = self._container_client.query_items(query=query, parameters=parameters)
items = [item async for item in items_iter]
containers = self._containers_for_query(memory_types)
items = []
for container in containers:
items_iter = container.query_items(query=query, parameters=parameters)
items.extend([item async for item in items_iter])
# Re-sort merged results to preserve chronological ordering
if len(containers) > 1:
items.sort(key=lambda x: x.get("created_at", ""), reverse=True)
except Exception as exc:
raise CosmosOperationError(f"async get_thread query failed: {exc}") from exc

Expand Down Expand Up @@ -1317,6 +1462,7 @@ async def _drain_cosmos_client(self) -> None:
logger.warning("Failed to close prior async Cosmos client during reconnect", exc_info=True)
self._cosmos_client = None
self._container_client = None
self._turns_container_client = None
self._counter_container_client = None
self._pipeline = None
if not self._processor_explicit:
Expand Down Expand Up @@ -1368,10 +1514,18 @@ def _init_pipeline(self) -> None:
self._sync_embeddings_client = sync_embeddings

if sync_container is not None:
# Build sync turns container if configured
sync_turns_container = None
if self._cosmos_turns_container:
try:
sync_turns_container = sync_db.get_container_client(self._cosmos_turns_container)
except Exception:
logger.warning("Failed to create sync turns container for pipeline")
self._pipeline = ProcessingPipeline(
cosmos_container=sync_container,
chat_client=self._chat_client,
embeddings_client=sync_embeddings,
cosmos_turns_container=sync_turns_container,
)
self._warn_on_embedding_dim_mismatch(sync_container)

Expand Down
Loading
Loading