Skip to content

Commit c03a1f1

Browse files
Aayush KatariaAayush Kataria
authored andcommitted
Resolving comments
1 parent 9207989 commit c03a1f1

8 files changed

Lines changed: 523 additions & 142 deletions

File tree

.env.template

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
# This is a template for the .env file. Copy this file to .env and fill in the values for your accounts.
2+
3+
# ---- Cosmos DB ----
24
COSMOS_DB_ENDPOINT=https://<your-account>.documents.azure.com:443/
5+
# COSMOS_DB__accountEndpoint is required for the Azure Functions change feed trigger
6+
# (identity-based connection). Set it to the same value as COSMOS_DB_ENDPOINT.
7+
COSMOS_DB__accountEndpoint=https://<your-account>.documents.azure.com:443/
38
COSMOS_DB_DATABASE=ai_memory
49
COSMOS_DB_CONTAINER=memories
10+
COSMOS_DB_COUNTERS_CONTAINER=counter
11+
COSMOS_DB_LEASE_CONTAINER=leases
512
COSMOS_DB_AUTOSCALE_MAX_RU=1000
613

14+
# ---- Change Feed Thresholds (set to 0 to disable) ----
15+
THREAD_SUMMARY_EVERY_N=0
16+
FACT_EXTRACTION_EVERY_N=0
17+
USER_SUMMARY_EVERY_N=0
18+
19+
# ---- AI Foundry / Azure OpenAI ----
720
AI_FOUNDRY_ENDPOINT=https://<your-account>.openai.azure.com/
821
AI_FOUNDRY_API_KEY=
922
EMBEDDING_MODEL=text-embedding-3-large
@@ -14,5 +27,6 @@ FULL_TEXT_LANGUAGE=en-US
1427

1528
LLM_MODEL=<your-model-deployment>
1629

30+
# ---- Azure Durable Functions ----
1731
ADF_ENDPOINT=http://localhost:7071/api
1832
ADF_KEY=

Samples/Demo.ipynb

Lines changed: 14 additions & 81 deletions
Large diffs are not rendered by default.

agent_memory_toolkit/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class MemoryType(str, Enum):
3636
summary = "summary"
3737
fact = "fact"
3838
user_summary = "user_summary"
39-
counter = "counter"
4039

4140

4241
# ---------------------------------------------------------------------------

azure_functions/activities.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ async def _get_cosmos_database_client():
126126
"""Return the Cosmos DB database client, connecting on first call."""
127127
global _cosmos_client, _cosmos_database_client
128128
if _cosmos_database_client is None:
129-
endpoint = os.environ["COSMOS_DB__accountEndpoint"]
129+
endpoint = os.environ.get("COSMOS_DB__accountEndpoint") or os.environ.get("COSMOS_DB_ENDPOINT")
130+
if not endpoint:
131+
raise ValueError(
132+
"Cosmos DB endpoint not configured. "
133+
"Set COSMOS_DB__accountEndpoint (required for change feed trigger) "
134+
"or COSMOS_DB_ENDPOINT."
135+
)
130136
database = os.environ["COSMOS_DB_DATABASE"]
131137
logger.info(
132138
"Connecting to Cosmos DB endpoint=%s database=%s",

azure_functions/function_app.py

Lines changed: 161 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
Required Cosmos DB containers:
5050
5151
- ``memories`` – existing container for memory documents
52+
- ``counter`` – dedicated container for change feed counter documents
53+
(configurable via ``COSMOS_DB_COUNTERS_CONTAINER``)
5254
- ``leases`` – auto-created by the trigger for change feed checkpointing
5355
"""
5456

@@ -59,6 +61,7 @@
5961

6062
import azure.functions as func
6163
import azure.durable_functions as df
64+
from azure.core import MatchConditions
6265
from azure.cosmos.exceptions import CosmosHttpResponseError, CosmosResourceNotFoundError
6366

6467
from activities import _get_cosmos_counter_container, bp as activities_bp
@@ -83,11 +86,49 @@ def _utc_now_iso() -> str:
8386
return datetime.now(timezone.utc).isoformat()
8487

8588

86-
async def increment_counter_by(counter_id: str, user_id: str, thread_id: str, count: int) -> tuple[int, int]:
89+
def _parse_threshold(name: str) -> int:
90+
"""Parse an integer threshold from an environment variable.
91+
92+
Returns 0 (disabled) if the variable is missing, empty, or not a valid
93+
integer, and logs a warning so misconfigurations are visible.
94+
"""
95+
raw = os.environ.get(name, "0")
96+
try:
97+
return int(raw)
98+
except (ValueError, TypeError):
99+
logger.warning(
100+
"Invalid value for %s=%r, defaulting to 0 (disabled)", name, raw,
101+
)
102+
return 0
103+
104+
105+
async def increment_counter_by(
106+
counter_id: str,
107+
user_id: str,
108+
thread_id: str,
109+
count: int,
110+
*,
111+
batch_max_lsn: int | None = None,
112+
) -> tuple[int, int]:
87113
"""Atomically increment a counter document by *count* using ETag concurrency.
88114
89115
Returns ``(old_count, new_count)``. Creates the counter document if it
90116
does not exist. Retries up to 3 times on ETag conflicts (HTTP 412).
117+
118+
If *batch_max_lsn* is provided, the counter stores it alongside the
119+
pre-increment count. On a change-feed retry (same batch replayed),
120+
the function detects the duplicate via LSN comparison and returns the
121+
cached ``(pre_batch_count, current_count)`` **without** writing,
122+
preserving threshold-crossing semantics for the caller.
123+
124+
.. note::
125+
126+
LSN-based replay detection is perfect for **thread-scoped** counters
127+
(single logical partition → monotonic LSNs). For **user-scoped**
128+
counters that aggregate across partitions, it handles the common
129+
single-partition-range retry but may not detect cross-partition
130+
interleaving. Deterministic orchestration instance IDs provide
131+
an additional safety net against duplicate orchestration starts.
91132
"""
92133
container = await _get_cosmos_counter_container()
93134
max_retries = 3
@@ -108,12 +149,27 @@ async def increment_counter_by(counter_id: str, user_id: str, thread_id: str, co
108149
except CosmosHttpResponseError:
109150
raise
110151

152+
# ---- Replay detection via LSN ----
153+
if (
154+
batch_max_lsn is not None
155+
and existing_doc is not None
156+
and existing_doc.get("last_batch_lsn") == batch_max_lsn
157+
):
158+
replay_old = existing_doc.get("last_batch_old_count", old_count)
159+
logger.info(
160+
"Counter replay detected counter_id=%s lsn=%s, returning cached result",
161+
counter_id, batch_max_lsn,
162+
)
163+
return (replay_old, old_count)
164+
111165
new_count = old_count + count
112166
new_doc = {
113167
"id": counter_id,
114168
"user_id": user_id,
115169
"thread_id": thread_id,
116170
"count": new_count,
171+
"last_batch_lsn": batch_max_lsn,
172+
"last_batch_old_count": old_count,
117173
"created_at": existing_doc.get("created_at", _utc_now_iso()) if existing_doc else _utc_now_iso(),
118174
"updated_at": _utc_now_iso(),
119175
}
@@ -124,10 +180,28 @@ async def increment_counter_by(counter_id: str, user_id: str, thread_id: str, co
124180
await container.upsert_item(
125181
body=new_doc,
126182
etag=etag,
127-
match_condition="IfMatch",
183+
match_condition=MatchConditions.IfNotModified,
128184
)
129185
else:
130-
await container.upsert_item(body=new_doc)
186+
# First-time creation: use create_item to avoid last-writer-wins
187+
# race when multiple Function instances see 404 concurrently.
188+
try:
189+
await container.create_item(body=new_doc)
190+
except CosmosHttpResponseError as create_exc:
191+
if create_exc.status_code == 409 and attempt < max_retries - 1:
192+
# Another instance created it first — retry with read-modify-write
193+
logger.warning(
194+
"Counter create conflict counter_id=%s attempt=%d/%d, retrying",
195+
counter_id, attempt + 1, max_retries,
196+
)
197+
continue
198+
if create_exc.status_code == 409:
199+
logger.warning(
200+
"Counter create conflict exhausted retries counter_id=%s, skipping",
201+
counter_id,
202+
)
203+
return (old_count, old_count)
204+
raise
131205
return (old_count, new_count)
132206
except CosmosHttpResponseError as exc:
133207
if exc.status_code == 412 and attempt < max_retries - 1:
@@ -330,16 +404,18 @@ async def process_changefeed_batch(documents: list[dict], starter) -> None:
330404
Extracted from the trigger function so it can be tested without the
331405
Durable Functions middleware.
332406
"""
333-
n_thread = int(os.environ.get("THREAD_SUMMARY_EVERY_N", "0"))
334-
n_facts = int(os.environ.get("FACT_EXTRACTION_EVERY_N", "0"))
335-
n_user = int(os.environ.get("USER_SUMMARY_EVERY_N", "0"))
407+
n_thread = _parse_threshold("THREAD_SUMMARY_EVERY_N")
408+
n_facts = _parse_threshold("FACT_EXTRACTION_EVERY_N")
409+
n_user = _parse_threshold("USER_SUMMARY_EVERY_N")
336410

337411
if n_thread == 0 and n_facts == 0 and n_user == 0:
338412
return # all processing disabled
339413

340414
# ---- Step 1: Filter to turns, group by scope ----
341415
thread_counts: dict[tuple[str, str], int] = defaultdict(int)
342416
user_counts: dict[str, int] = defaultdict(int)
417+
thread_max_lsn: dict[tuple[str, str], int] = {}
418+
user_max_lsn: dict[str, int] = {}
343419

344420
for doc in documents:
345421
# Counter writes land in the separate counter container, so only raw
@@ -356,6 +432,13 @@ async def process_changefeed_batch(documents: list[dict], starter) -> None:
356432
thread_counts[(user_id, thread_id)] += 1
357433
user_counts[user_id] += 1
358434

435+
# Track max _lsn per scope for replay detection
436+
lsn = doc.get("_lsn")
437+
if lsn is not None:
438+
key = (user_id, thread_id)
439+
thread_max_lsn[key] = max(thread_max_lsn.get(key, 0), lsn)
440+
user_max_lsn[user_id] = max(user_max_lsn.get(user_id, 0), lsn)
441+
359442
thread_counters_enabled = (n_thread > 0 or n_facts > 0)
360443
user_counters_enabled = (n_user > 0)
361444
enabled_thread_groups = len(thread_counts) if thread_counters_enabled else 0
@@ -370,59 +453,102 @@ async def process_changefeed_batch(documents: list[dict], starter) -> None:
370453
)
371454

372455
# ---- Step 2: Thread-scoped counters and threshold checks ----
456+
orchestration_errors: list[Exception] = []
457+
373458
if thread_counters_enabled:
374459
for (user_id, thread_id), batch_count in thread_counts.items():
460+
lsn = thread_max_lsn.get((user_id, thread_id))
375461
old_count, new_count = await increment_counter_by(
376462
f"thread_counter_{user_id}_{thread_id}", user_id, thread_id, batch_count,
463+
batch_max_lsn=lsn,
377464
)
378465

379466
if n_thread > 0 and crosses_threshold(old_count, new_count, n_thread):
467+
bucket = new_count // n_thread
468+
instance_id = f"ts_{user_id}_{thread_id}_{bucket}"
380469
logger.info(
381-
"on_memory_change: triggering thread_summary user_id=%s thread_id=%s count=%d",
382-
user_id, thread_id, new_count,
383-
)
384-
await starter.start_new(
385-
"memory_orchestrator",
386-
client_input={
387-
"thread_summary_only": True,
388-
"user_id": user_id,
389-
"thread_id": thread_id,
390-
},
470+
"on_memory_change: triggering thread_summary user_id=%s thread_id=%s count=%d instance=%s",
471+
user_id, thread_id, new_count, instance_id,
391472
)
473+
try:
474+
await starter.start_new(
475+
"memory_orchestrator",
476+
instance_id=instance_id,
477+
client_input={
478+
"thread_summary_only": True,
479+
"user_id": user_id,
480+
"thread_id": thread_id,
481+
},
482+
)
483+
except Exception as exc:
484+
logger.exception(
485+
"Failed to start thread_summary orchestration user_id=%s thread_id=%s",
486+
user_id, thread_id,
487+
)
488+
orchestration_errors.append(exc)
392489

393490
if n_facts > 0 and crosses_threshold(old_count, new_count, n_facts):
491+
bucket = new_count // n_facts
492+
instance_id = f"ef_{user_id}_{thread_id}_{bucket}"
394493
logger.info(
395-
"on_memory_change: triggering extract_facts user_id=%s thread_id=%s count=%d",
396-
user_id, thread_id, new_count,
397-
)
398-
await starter.start_new(
399-
"memory_orchestrator",
400-
client_input={
401-
"extract_facts_only": True,
402-
"user_id": user_id,
403-
"thread_id": thread_id,
404-
},
494+
"on_memory_change: triggering extract_facts user_id=%s thread_id=%s count=%d instance=%s",
495+
user_id, thread_id, new_count, instance_id,
405496
)
497+
try:
498+
await starter.start_new(
499+
"memory_orchestrator",
500+
instance_id=instance_id,
501+
client_input={
502+
"extract_facts_only": True,
503+
"user_id": user_id,
504+
"thread_id": thread_id,
505+
},
506+
)
507+
except Exception as exc:
508+
logger.exception(
509+
"Failed to start extract_facts orchestration user_id=%s thread_id=%s",
510+
user_id, thread_id,
511+
)
512+
orchestration_errors.append(exc)
406513

407514
# ---- Step 3: User-scoped counters and threshold checks ----
408515
if user_counters_enabled:
409516
for user_id, batch_count in user_counts.items():
517+
lsn = user_max_lsn.get(user_id)
410518
old_count, new_count = await increment_counter_by(
411519
f"user_counter_{user_id}", user_id, USER_COUNTER_THREAD_ID, batch_count,
520+
batch_max_lsn=lsn,
412521
)
413522

414523
if crosses_threshold(old_count, new_count, n_user):
524+
bucket = new_count // n_user
525+
instance_id = f"us_{user_id}_{bucket}"
415526
logger.info(
416-
"on_memory_change: triggering user_summary user_id=%s count=%d",
417-
user_id, new_count,
418-
)
419-
await starter.start_new(
420-
"memory_orchestrator",
421-
client_input={
422-
"user_summary_only": True,
423-
"user_id": user_id,
424-
},
527+
"on_memory_change: triggering user_summary user_id=%s count=%d instance=%s",
528+
user_id, new_count, instance_id,
425529
)
530+
try:
531+
await starter.start_new(
532+
"memory_orchestrator",
533+
instance_id=instance_id,
534+
client_input={
535+
"user_summary_only": True,
536+
"user_id": user_id,
537+
},
538+
)
539+
except Exception as exc:
540+
logger.exception(
541+
"Failed to start user_summary orchestration user_id=%s",
542+
user_id,
543+
)
544+
orchestration_errors.append(exc)
545+
546+
# Re-raise so the change feed batch retries and thresholds re-fire
547+
if orchestration_errors:
548+
raise RuntimeError(
549+
f"Failed to start {len(orchestration_errors)} orchestration(s); "
550+
"raising to retry the change feed batch"
551+
) from orchestration_errors[0]
426552
@df_app.cosmos_db_trigger(
427553
arg_name="documents",
428554
connection="COSMOS_DB",

tests/integration/test_changefeed_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def cosmos_clients():
3838
from azure.cosmos import CosmosClient
3939
from azure.identity import DefaultAzureCredential
4040

41-
endpoint = os.environ["COSMOS_DB__accountEndpoint"]
41+
endpoint = os.environ.get("COSMOS_DB__accountEndpoint") or os.environ.get("COSMOS_DB_ENDPOINT")
4242
database_name = os.environ.get("COSMOS_DB_DATABASE", "ai_memory")
4343
memories_container_name = os.environ.get("COSMOS_DB_CONTAINER", "memories")
4444
counter_container_name = os.environ.get("COSMOS_DB_COUNTERS_CONTAINER", "counter")

0 commit comments

Comments
 (0)