From a2d4a1e78932c0682bb824dc6493f895a1338846 Mon Sep 17 00:00:00 2001 From: ved015 Date: Wed, 3 Jun 2026 20:41:51 +0530 Subject: [PATCH 1/4] Add v2 S3 original storage and hybrid search --- .env.example | 30 +++ src/api/routes/v2/activities.py | 12 + src/api/routes/v2/memory.py | 122 ++++++++- src/api/routes/v2/workflows.py | 111 ++++++-- src/api/schemas.py | 36 ++- src/config/settings.py | 99 +++++++ src/storage/original.py | 402 ++++++++++++++++++++++++++++ tests/api/test_v2_hybrid_search.py | 90 +++++++ tests/unit/test_original_storage.py | 98 +++++++ 9 files changed, 980 insertions(+), 20 deletions(-) create mode 100644 src/storage/original.py create mode 100644 tests/api/test_v2_hybrid_search.py create mode 100644 tests/unit/test_original_storage.py diff --git a/.env.example b/.env.example index dca76cfb..1b45ea5b 100644 --- a/.env.example +++ b/.env.example @@ -22,6 +22,36 @@ PINECONE_REGION=us-east-1 # EMBEDDING_MODEL=all-MiniLM-L6-v2 EMBEDDING_MODEL=gemini-embedding-001 +# ============================================================================= +# Original Storage + v2 Hybrid Search (optional, v2 only) +# ============================================================================= +ORIGINAL_STORAGE_ENABLED=false +ORIGINAL_STORAGE_PROVIDER=s3 +ORIGINAL_STORAGE_FAIL_CLOSED=false +ORIGINAL_STORAGE_TIMEOUT_SECONDS=180 + +ORIGINAL_S3_BUCKET= +ORIGINAL_S3_REGION=us-east-1 +ORIGINAL_S3_PREFIX=originals +ORIGINAL_S3_ENDPOINT_URL= +ORIGINAL_S3_KMS_KEY_ID= +ORIGINAL_S3_MULTIPART_THRESHOLD_BYTES=8388608 +ORIGINAL_S3_MULTIPART_CHUNK_BYTES=8388608 + +ORIGINAL_CHUNK_SIZE_TOKENS=350 +ORIGINAL_CHUNK_OVERLAP_TOKENS=40 +ORIGINAL_INDEX_BATCH_SIZE=64 +ORIGINAL_EMBED_CONCURRENCY=4 +ORIGINAL_INDEX_CONCURRENCY=2 +ORIGINAL_BATCH_ITEM_CONCURRENCY=3 +ORIGINAL_MAX_BYTES=10485760 +ORIGINAL_INCLUDE_AGENT_RESPONSE=true +ORIGINAL_INCLUDE_IMAGE_URL=false + +HYBRID_SEARCH_MEMORY_TOP_K=10 +HYBRID_SEARCH_ORIGINAL_TOP_K=10 +HYBRID_SEARCH_MIN_SCORE=0.0 + # ============================================================================= # Database Configuration # ============================================================================= diff --git a/src/api/routes/v2/activities.py b/src/api/routes/v2/activities.py index 930128d7..696bfe48 100644 --- a/src/api/routes/v2/activities.py +++ b/src/api/routes/v2/activities.py @@ -12,6 +12,7 @@ from src.billing.context import use_billing_context from src.billing.service import commit_job_billing, release_job_billing from src.jobs.durable import get_default_job_store +from src.storage.original import preserve_original try: # pragma: no cover - no-op fallback keeps imports working without SDK. from temporalio import activity @@ -179,6 +180,16 @@ async def memory_run_pipeline_activity(payload: Dict[str, Any]) -> Dict[str, Any return await memory_v1._run_ingest_payload(payload, payload["user_id"]) +@activity.defn +async def memory_store_original_activity(payload: Dict[str, Any]) -> Dict[str, Any]: + pipeline = get_ingest_pipeline() + return await preserve_original( + payload, + vector_store=pipeline.vector_store, + embed_fn=pipeline.embed_fn, + ) + + @activity.defn async def memory_scrape_activity(payload: Dict[str, Any]) -> Dict[str, Any]: def _run_scrape() -> Dict[str, Any]: @@ -286,6 +297,7 @@ async def scanner_phase2_activity(payload: Dict[str, Any]) -> Dict[str, Any]: memory_classify_activity, memory_domain_activity, memory_run_pipeline_activity, + memory_store_original_activity, memory_scrape_activity, scanner_scan_activity, scanner_phase2_activity, diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index f6c68b29..2b559d05 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -7,7 +7,12 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse -from src.api.dependencies import enforce_rate_limit, require_api_key, require_ready +from src.api.dependencies import ( + enforce_rate_limit, + get_retrieval_pipeline, + require_api_key, + require_ready, +) from src.api.routes import memory as memory_v1 from src.api.routes.v2.shared import ( _error, @@ -18,10 +23,20 @@ read_user_job, ) from src.api.routes.v2.temporal_client import start_job_workflow -from src.api.schemas import APIResponse, BatchIngestRequest, IngestRequest, ScrapeRequest, StatusEnum +from src.api.schemas import ( + APIResponse, + BatchIngestRequest, + HybridSearchRequest, + HybridSearchResponse, + IngestRequest, + ScrapeRequest, + SourceRecord, + StatusEnum, +) from src.billing import InsufficientCredits, get_default_billing_service from src.config import settings from src.jobs.durable import QUEUED, get_default_job_store, idempotency_key, new_attempt_id, stable_hash +from src.storage.original import ORIGINAL_CHUNK_DOMAIN, original_config_snapshot router = APIRouter( prefix="/v2/memory", @@ -44,6 +59,18 @@ def _durable_job_id(job_type: str, fields: Dict[str, Any]) -> str: return f"{job_type}:{idempotency_key(job_type, fields)}" +def _attach_original_storage_config(payload: Dict[str, Any]) -> None: + payload["original_storage_enabled"] = bool(settings.original_storage_enabled) + payload["original_storage_fail_closed"] = bool(settings.original_storage_fail_closed) + payload["original_storage_timeout_seconds"] = float( + settings.original_storage_timeout_seconds + ) + payload["original_batch_item_concurrency"] = int( + settings.original_batch_item_concurrency + ) + payload["original_config"] = original_config_snapshot() + + class WorkflowStartFailed(RuntimeError): def __init__(self, job: Dict[str, Any], error: str) -> None: super().__init__(error) @@ -122,6 +149,7 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De payload = req.model_dump() payload["user_id"] = user_id payload["timeout_seconds"] = float(settings.memory_ingest_timeout_seconds) + _attach_original_storage_config(payload) idempotency_fields = { "user_id": user_id, "org_id": payload.get("org_id", "default"), @@ -132,6 +160,7 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De "image_url": req.image_url, "effort_level": req.effort_level, }), + "original_storage_enabled": bool(settings.original_storage_enabled), } job_id = _durable_job_id("memory_ingest", idempotency_fields) billing_service = get_default_billing_service() @@ -217,9 +246,11 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user min(len(req.items) * float(settings.memory_ingest_timeout_seconds), 3600.0), ), } + _attach_original_storage_config(payload) idempotency_fields = { "user_id": user_id, "content_hash": _content_hash({"items": items}), + "original_storage_enabled": bool(settings.original_storage_enabled), } job_id = _durable_job_id("memory_batch_ingest", idempotency_fields) billing_service = get_default_billing_service() @@ -278,6 +309,93 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user return _error(request, str(exc), 500, elapsed_ms(start)) +async def _search_original_chunks( + query: str, + user_id: str, + top_k: int, +) -> list[SourceRecord]: + pipeline = get_retrieval_pipeline() + raw = await pipeline.vector_store.search_by_text( + query_text=query, + top_k=top_k, + filters={"user_id": user_id, "domain": ORIGINAL_CHUNK_DOMAIN}, + ) + results: list[SourceRecord] = [] + for item in raw: + score = float(item.score or 0.0) + if score < float(settings.hybrid_search_min_score): + continue + results.append( + SourceRecord( + domain=ORIGINAL_CHUNK_DOMAIN, + content=item.content, + score=round(score, 3), + metadata={"id": item.id, **item.metadata}, + ) + ) + return results + + +@router.post( + "/hybrid-search", + response_model=APIResponse, + summary="v2-only hybrid search across extracted memories and original chunks", +) +async def hybrid_search_memory_v2( + req: HybridSearchRequest, + request: Request, + user: dict = Depends(require_api_key), +): + start = time.perf_counter() + pipeline = get_retrieval_pipeline() + user_id = memory_v1._current_user_id(user, req.user_id) + memory_top_k = req.memory_top_k or int(settings.hybrid_search_memory_top_k) + original_top_k = req.original_top_k or int(settings.hybrid_search_original_top_k) + + try: + memory_results: list[SourceRecord] = [] + if "profile" in req.domains: + memory_results.extend(memory_v1._search_profile(pipeline, user_id)) + if "temporal" in req.domains: + memory_results.extend( + memory_v1._search_temporal( + pipeline, + req.query, + user_id, + memory_top_k, + ) + ) + if "summary" in req.domains: + memory_results.extend( + await memory_v1._search_summary( + pipeline, + req.query, + user_id, + memory_top_k, + ) + ) + + original_chunks: list[SourceRecord] = [] + if req.include_original_chunks and settings.original_storage_enabled: + original_chunks = await _search_original_chunks( + req.query, + user_id, + original_top_k, + ) + + all_results = memory_results + original_chunks + data = HybridSearchResponse( + memory_results=memory_results, + original_chunks=original_chunks, + results=all_results, + total=len(all_results), + original_storage_enabled=bool(settings.original_storage_enabled), + ) + return _wrap(request, data, elapsed_ms(start)) + except Exception as exc: + return _error(request, str(exc), 500, elapsed_ms(start)) + + @scrape_router.post("/scrape", response_model=APIResponse, summary="Start an async durable scrape job") async def scrape_chat_link_v2(req: ScrapeRequest, request: Request): start = time.perf_counter() diff --git a/src/api/routes/v2/workflows.py b/src/api/routes/v2/workflows.py index 73f5207f..62d5f987 100644 --- a/src/api/routes/v2/workflows.py +++ b/src/api/routes/v2/workflows.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from datetime import timedelta from typing import Any, Dict, List @@ -54,6 +55,41 @@ async def _execute(name: str, arg: Any, timeout_seconds: float) -> Any: ) +def _original_enabled(payload: Dict[str, Any]) -> bool: + return bool(payload.get("original_storage_enabled")) + + +def _original_timeout(payload: Dict[str, Any]) -> float: + return float(payload.get("original_storage_timeout_seconds") or 180.0) + + +def _start_original_task(job_id: str, payload: Dict[str, Any]): + if not _original_enabled(payload): + return None + return asyncio.create_task( + _execute( + "memory_store_original_activity", + {**payload, "job_id": job_id}, + _original_timeout(payload), + ) + ) + + +async def _await_original_task(task, payload: Dict[str, Any]) -> Dict[str, Any]: + if task is None: + return {"status": "disabled", "indexed_chunks": 0} + try: + return await task + except Exception as exc: + if bool(payload.get("original_storage_fail_closed")): + raise + return { + "status": "failed", + "error": str(exc) or exc.__class__.__name__, + "indexed_chunks": 0, + } + + async def _mark_dead(job_id: str, exc: BaseException) -> Dict[str, Any]: error = str(exc) or exc.__class__.__name__ await _execute( @@ -103,12 +139,17 @@ async def run(self, input: Dict[str, Any]) -> Dict[str, Any]: timeout = float(payload.get("timeout_seconds") or 120.0) try: await _execute("mark_job_running_activity", job_id, 30) + original_task = _start_original_task(job_id, payload) if payload.get("effort_level") == "high": result = await _execute( "memory_run_pipeline_activity", {**payload, **billing_activity}, timeout, ) + result["original_storage"] = await _await_original_task( + original_task, + payload, + ) await _execute( "mark_job_succeeded_activity", {"job_id": job_id, "result": result}, @@ -211,6 +252,10 @@ async def run(self, input: Dict[str, Any]) -> Dict[str, Any]: ) result["code"] = code + result["original_storage"] = await _await_original_task( + original_task, + payload, + ) await _execute( "mark_job_succeeded_activity", {"job_id": job_id, "result": result}, 30 ) @@ -236,31 +281,63 @@ async def run(self, input: Dict[str, Any]) -> Dict[str, Any]: items = list(payload.get("items") or []) total_timeout = float(payload.get("timeout_seconds") or 3600.0) item_timeout = max(total_timeout / max(len(items), 1), 1.0) - results = [] - for index, item in enumerate(items): + concurrency = max( + int(payload.get("original_batch_item_concurrency") or 1), + 1, + ) + results: List[Any] = [None] * len(items) + completed = 0 + + async def _run_item(index: int, item: Dict[str, Any]): item_payload = dict(item) item_payload["user_id"] = ( item_payload.get("user_id") or payload["user_id"] ) - item_payload.update(billing_activity) - item_result = await _execute( + for key in ( + "original_storage_enabled", + "original_storage_fail_closed", + "original_storage_timeout_seconds", + "original_config", + ): + if key in payload and key not in item_payload: + item_payload[key] = payload[key] + + original_task = _start_original_task(job_id, item_payload) + extraction_task = asyncio.create_task(_execute( "memory_run_pipeline_activity", - item_payload, + {**item_payload, **billing_activity}, item_timeout, + )) + item_result = await extraction_task + item_result["original_storage"] = await _await_original_task( + original_task, + item_payload, ) - results.append(item_result) - await _execute( - "mark_job_progress_activity", - { - "job_id": job_id, - "progress": { - "step": "batch_ingest", - "completed": index + 1, - "total": len(items), + return index, item_result + + for start in range(0, len(items), concurrency): + window = [ + asyncio.create_task(_run_item(index, item)) + for index, item in enumerate( + items[start:start + concurrency], + start=start, + ) + ] + for index, item_result in await asyncio.gather(*window): + results[index] = item_result + completed += 1 + await _execute( + "mark_job_progress_activity", + { + "job_id": job_id, + "progress": { + "step": "batch_ingest", + "completed": completed, + "total": len(items), + }, }, - }, - 30, - ) + 30, + ) result = {"results": results} await _execute( "mark_job_succeeded_activity", {"job_id": job_id, "result": result}, 30 diff --git a/src/api/schemas.py b/src/api/schemas.py index ff10e6de..9c652ff1 100644 --- a/src/api/schemas.py +++ b/src/api/schemas.py @@ -7,7 +7,6 @@ from __future__ import annotations -from datetime import datetime from enum import Enum import re from typing import Any, Dict, List, Optional @@ -195,6 +194,41 @@ class SearchResponse(BaseModel): total: int = 0 +class HybridSearchRequest(UserScopedModel): + """v2-only search across extracted memories and original chunks.""" + query: str = Field(..., min_length=1, max_length=5_000) + user_id: str = Field(..., min_length=1, max_length=256) + domains: List[str] = Field( + default=["profile", "temporal", "summary"], + description="Extracted memory domains to search", + ) + memory_top_k: Optional[int] = Field(default=None, ge=1, le=100) + original_top_k: Optional[int] = Field(default=None, ge=1, le=100) + include_original_chunks: bool = True + + @field_validator("query") + @classmethod + def strip_hybrid_query(cls, v: str) -> str: + return v.strip() + + @field_validator("domains") + @classmethod + def validate_hybrid_domains(cls, v: List[str]) -> List[str]: + allowed = {"profile", "temporal", "summary"} + for d in v: + if d not in allowed: + raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}") + return v + + +class HybridSearchResponse(BaseModel): + memory_results: List[SourceRecord] = Field(default_factory=list) + original_chunks: List[SourceRecord] = Field(default_factory=list) + results: List[SourceRecord] = Field(default_factory=list) + total: int = 0 + original_storage_enabled: bool = False + + # ── Scrape (extract from shared chat links) ──────────────────────────────── class ScrapeRequest(BaseModel): diff --git a/src/config/settings.py b/src/config/settings.py index acc7fc0e..caca7b39 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -131,6 +131,10 @@ class Settings(BaseSettings): default=None, description="AWS secret access key for Bedrock" ) + aws_session_token: Optional[str] = Field( + default=None, + description="Optional AWS session token for temporary credentials" + ) bedrock_region: str = Field( default="us-east-1", description="AWS region for Bedrock" @@ -168,6 +172,101 @@ class Settings(BaseSettings): default=120.0, description="Overall memory ingest timeout in seconds", ) + + original_storage_enabled: bool = Field( + default=False, + description="Enable v2-only raw original storage and original chunk indexing", + ) + original_storage_provider: str = Field( + default="s3", + description="Original storage provider: s3", + ) + original_storage_fail_closed: bool = Field( + default=False, + description="Fail v2 ingest jobs when original storage/indexing fails", + ) + original_storage_timeout_seconds: float = Field( + default=180.0, + description="Temporal activity timeout for original storage/indexing", + ) + original_s3_bucket: Optional[str] = Field( + default=None, + description="S3 bucket for raw original documents", + ) + original_s3_region: str = Field( + default="us-east-1", + description="AWS region for ORIGINAL_S3_BUCKET", + ) + original_s3_prefix: str = Field( + default="originals", + description="Prefix inside the S3 bucket for original documents", + ) + original_s3_endpoint_url: Optional[str] = Field( + default=None, + description="Optional S3-compatible endpoint URL", + ) + original_s3_kms_key_id: Optional[str] = Field( + default=None, + description="Optional KMS key ID/ARN for SSE-KMS encryption", + ) + original_s3_multipart_threshold_bytes: int = Field( + default=8 * 1024 * 1024, + description="Use multipart upload above this serialized object size", + ) + original_s3_multipart_chunk_bytes: int = Field( + default=8 * 1024 * 1024, + description="Multipart upload chunk size for original S3 objects", + ) + original_chunk_size_tokens: int = Field( + default=350, + description="Approximate token target for original searchable chunks", + ) + original_chunk_overlap_tokens: int = Field( + default=40, + description="Approximate token overlap between original searchable chunks", + ) + original_index_batch_size: int = Field( + default=64, + description="Number of original chunks to upsert per vector-store batch", + ) + original_embed_concurrency: int = Field( + default=4, + description="Concurrent embedding calls for original chunks", + ) + original_index_concurrency: int = Field( + default=2, + description="Concurrent vector upsert batches for original chunks", + ) + original_batch_item_concurrency: int = Field( + default=3, + description="Concurrent v2 batch-ingest items when original storage is enabled", + ) + original_max_bytes: int = Field( + default=10 * 1024 * 1024, + description="Maximum serialized original object size accepted for preservation", + ) + original_include_agent_response: bool = Field( + default=True, + description="Include agent_response in stored original content", + ) + original_include_image_url: bool = Field( + default=False, + description="Include image_url in stored original JSON and indexed text", + ) + hybrid_search_memory_top_k: int = Field( + default=10, + description="Default number of extracted memory hits for v2 hybrid search", + ) + hybrid_search_original_top_k: int = Field( + default=10, + description="Default number of original chunk hits for v2 hybrid search", + ) + hybrid_search_min_score: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="Minimum score for original chunks returned by v2 hybrid search", + ) temporal_address: str = Field( default="localhost:7233", description="Temporal frontend address for durable v2 workflows", diff --git a/src/storage/original.py b/src/storage/original.py new file mode 100644 index 00000000..bd43e3c5 --- /dev/null +++ b/src/storage/original.py @@ -0,0 +1,402 @@ +"""Original document preservation for v2 memory ingest. + +This module stores the raw input in an object store and indexes searchable +chunks in the configured vector backend. It is intentionally independent from +the extraction pipeline so v2 workflows can run both branches in parallel. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import io +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timezone +from functools import partial +from typing import Any, Callable, Dict, List, Mapping, Optional + +from src.config import settings +from src.config.effort import chunk_text, estimate_tokens +from src.storage.base import BaseVectorStore +from src.storage.factory import get_vector_store + +logger = logging.getLogger("xmem.storage.original") + +ORIGINAL_CHUNK_DOMAIN = "original_chunk" + + +class OriginalStorageError(RuntimeError): + """Raised when original preservation cannot complete.""" + + +@dataclass(frozen=True) +class OriginalConfig: + enabled: bool + provider: str + environment: str + bucket: Optional[str] + region: str + prefix: str + endpoint_url: Optional[str] + kms_key_id: Optional[str] + multipart_threshold_bytes: int + multipart_chunk_bytes: int + chunk_size_tokens: int + chunk_overlap_tokens: int + index_batch_size: int + embed_concurrency: int + index_concurrency: int + max_bytes: int + include_agent_response: bool + include_image_url: bool + + +def original_config_snapshot() -> Dict[str, Any]: + """Capture non-secret original-storage config into a durable job payload.""" + return { + "enabled": bool(settings.original_storage_enabled), + "provider": settings.original_storage_provider, + "environment": settings.environment, + "bucket": settings.original_s3_bucket, + "region": settings.original_s3_region, + "prefix": settings.original_s3_prefix, + "endpoint_url": settings.original_s3_endpoint_url, + "kms_key_id": settings.original_s3_kms_key_id, + "multipart_threshold_bytes": int(settings.original_s3_multipart_threshold_bytes), + "multipart_chunk_bytes": int(settings.original_s3_multipart_chunk_bytes), + "chunk_size_tokens": int(settings.original_chunk_size_tokens), + "chunk_overlap_tokens": int(settings.original_chunk_overlap_tokens), + "index_batch_size": int(settings.original_index_batch_size), + "embed_concurrency": int(settings.original_embed_concurrency), + "index_concurrency": int(settings.original_index_concurrency), + "max_bytes": int(settings.original_max_bytes), + "include_agent_response": bool(settings.original_include_agent_response), + "include_image_url": bool(settings.original_include_image_url), + } + + +def _config_from_payload(payload: Mapping[str, Any]) -> OriginalConfig: + raw = dict(payload.get("original_config") or {}) + if not raw: + raw = original_config_snapshot() + return OriginalConfig( + enabled=bool(raw.get("enabled", settings.original_storage_enabled)), + provider=str(raw.get("provider") or "s3").strip().lower(), + environment=str(raw.get("environment") or settings.environment), + bucket=raw.get("bucket") or None, + region=str(raw.get("region") or "us-east-1"), + prefix=str(raw.get("prefix") or "originals").strip("/"), + endpoint_url=raw.get("endpoint_url") or None, + kms_key_id=raw.get("kms_key_id") or None, + multipart_threshold_bytes=max(int(raw.get("multipart_threshold_bytes") or 1), 1), + multipart_chunk_bytes=max(int(raw.get("multipart_chunk_bytes") or 1), 1), + chunk_size_tokens=max(int(raw.get("chunk_size_tokens") or 350), 1), + chunk_overlap_tokens=max(int(raw.get("chunk_overlap_tokens") or 0), 0), + index_batch_size=max(int(raw.get("index_batch_size") or 64), 1), + embed_concurrency=max(int(raw.get("embed_concurrency") or 4), 1), + index_concurrency=max(int(raw.get("index_concurrency") or 2), 1), + max_bytes=max(int(raw.get("max_bytes") or 1), 1), + include_agent_response=bool(raw.get("include_agent_response", True)), + include_image_url=bool(raw.get("include_image_url", False)), + ) + + +class S3OriginalStore: + def __init__(self, cfg: OriginalConfig) -> None: + if not cfg.bucket: + raise OriginalStorageError("ORIGINAL_S3_BUCKET is required.") + if cfg.provider != "s3": + raise OriginalStorageError( + f"Unsupported ORIGINAL_STORAGE_PROVIDER={cfg.provider!r}." + ) + self.cfg = cfg + self._client = None + + @property + def client(self): + if self._client is None: + import boto3 + from botocore.config import Config + + kwargs: Dict[str, Any] = { + "region_name": self.cfg.region, + "config": Config(read_timeout=60), + } + if self.cfg.endpoint_url: + kwargs["endpoint_url"] = self.cfg.endpoint_url + if settings.aws_access_key_id and settings.aws_secret_access_key: + kwargs["aws_access_key_id"] = settings.aws_access_key_id + kwargs["aws_secret_access_key"] = settings.aws_secret_access_key + if settings.aws_session_token: + kwargs["aws_session_token"] = settings.aws_session_token + + self._client = boto3.client("s3", **kwargs) + return self._client + + def put_json(self, key: str, body: bytes) -> Dict[str, Any]: + extra_args = { + "ContentType": "application/json", + "ServerSideEncryption": "AES256", + } + if self.cfg.kms_key_id: + extra_args["ServerSideEncryption"] = "aws:kms" + extra_args["SSEKMSKeyId"] = self.cfg.kms_key_id + + if len(body) >= self.cfg.multipart_threshold_bytes: + from boto3.s3.transfer import TransferConfig + + transfer_config = TransferConfig( + multipart_threshold=self.cfg.multipart_threshold_bytes, + multipart_chunksize=self.cfg.multipart_chunk_bytes, + ) + self.client.upload_fileobj( + io.BytesIO(body), + self.cfg.bucket, + key, + ExtraArgs=extra_args, + Config=transfer_config, + ) + return {"bucket": self.cfg.bucket, "key": key, "etag": None} + + response = self.client.put_object( + Bucket=self.cfg.bucket, + Key=key, + Body=body, + **extra_args, + ) + return { + "bucket": self.cfg.bucket, + "key": key, + "etag": response.get("ETag"), + } + + +def _sha256_text(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + +def _sha256_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _utc_now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _user_hash(user_id: str) -> str: + return _sha256_text(user_id)[:24] + + +def _content_parts(payload: Mapping[str, Any], cfg: OriginalConfig) -> List[str]: + parts = [] + user_query = str(payload.get("user_query") or "").strip() + agent_response = str(payload.get("agent_response") or "").strip() + image_url = str(payload.get("image_url") or "").strip() + + if user_query: + parts.append(f"[User]\n{user_query}") + if cfg.include_agent_response and agent_response: + parts.append(f"[Assistant]\n{agent_response}") + if cfg.include_image_url and image_url: + parts.append(f"[Image URL]\n{image_url}") + return parts + + +def _build_original_json( + payload: Mapping[str, Any], + cfg: OriginalConfig, + created_at: str, + original_doc_id: str, + content_sha256: str, +) -> Dict[str, Any]: + obj: Dict[str, Any] = { + "schema_version": 1, + "original_doc_id": original_doc_id, + "user_id_hash": _user_hash(str(payload.get("user_id") or "")), + "source_type": payload.get("source_type") or "conversation", + "content_sha256": content_sha256, + "created_at": created_at, + "session_datetime": payload.get("session_datetime") or "", + "request_id": payload.get("request_id") or "", + "job_id": payload.get("job_id") or "", + "user_query": payload.get("user_query") or "", + } + if cfg.include_agent_response: + obj["agent_response"] = payload.get("agent_response") or "" + if cfg.include_image_url: + obj["image_url"] = payload.get("image_url") or "" + return obj + + +def _chunk_original_text(text: str, cfg: OriginalConfig) -> List[str]: + text = text.strip() + if not text: + return [] + if estimate_tokens(text) <= cfg.chunk_size_tokens: + return [text] + return chunk_text( + text, + chunk_size_tokens=cfg.chunk_size_tokens, + overlap_tokens=cfg.chunk_overlap_tokens, + ) + + +async def _embed_chunks( + chunks: List[str], + embed_fn: Callable[[str], List[float]], + concurrency: int, +) -> List[List[float]]: + semaphore = asyncio.Semaphore(concurrency) + loop = asyncio.get_running_loop() + + async def _embed(chunk: str) -> List[float]: + async with semaphore: + return await loop.run_in_executor(None, embed_fn, chunk) + + return await asyncio.gather(*[_embed(chunk) for chunk in chunks]) + + +async def _index_chunks( + *, + vector_store: BaseVectorStore, + chunks: List[str], + embeddings: List[List[float]], + metadatas: List[Dict[str, Any]], + ids: List[str], + batch_size: int, + concurrency: int, +) -> None: + semaphore = asyncio.Semaphore(concurrency) + loop = asyncio.get_running_loop() + + async def _upsert_batch(start: int) -> None: + end = min(start + batch_size, len(chunks)) + async with semaphore: + await loop.run_in_executor( + None, + partial( + vector_store.add, + texts=chunks[start:end], + embeddings=embeddings[start:end], + ids=ids[start:end], + metadata=metadatas[start:end], + ), + ) + + await asyncio.gather(*[_upsert_batch(i) for i in range(0, len(chunks), batch_size)]) + + +async def preserve_original( + payload: Mapping[str, Any], + *, + vector_store: Optional[BaseVectorStore] = None, + embed_fn: Optional[Callable[[str], List[float]]] = None, +) -> Dict[str, Any]: + """Store the raw original and index searchable chunks. + + The operation is retry-safe: S3 keys and vector IDs are deterministic for + the same user/content pair, so Temporal can replay the activity. + """ + cfg = _config_from_payload(payload) + if not cfg.enabled: + return {"status": "disabled", "indexed_chunks": 0} + + parts = _content_parts(payload, cfg) + combined_text = "\n\n".join(parts).strip() + if not combined_text: + return {"status": "skipped", "reason": "empty_original", "indexed_chunks": 0} + + user_id = str(payload.get("user_id") or "default") + content_sha256 = _sha256_text(combined_text) + identity = { + "user_id": user_id, + "content_sha256": content_sha256, + "session_datetime": payload.get("session_datetime") or "", + "source_type": payload.get("source_type") or "conversation", + } + original_doc_id = _sha256_text(json.dumps(identity, sort_keys=True))[:32] + created_at = _utc_now_iso() + original_obj = _build_original_json( + payload, cfg, created_at, original_doc_id, content_sha256 + ) + body = json.dumps(original_obj, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + if len(body) > cfg.max_bytes: + raise OriginalStorageError( + f"Original object is {len(body)} bytes; max is {cfg.max_bytes}." + ) + + key = ( + f"{cfg.prefix}/{cfg.environment}/" + f"{_user_hash(user_id)}/{original_doc_id}.json" + ) + store = S3OriginalStore(cfg) + write_result = await asyncio.to_thread(store.put_json, key, body) + + chunks = _chunk_original_text(combined_text, cfg) + if not chunks: + return { + "status": "stored", + "original_doc_id": original_doc_id, + "bucket": write_result["bucket"], + "s3_key": write_result["key"], + "indexed_chunks": 0, + "content_sha256": content_sha256, + } + + if vector_store is None: + vector_store = get_vector_store(namespace=settings.pinecone_namespace) + if embed_fn is None: + from src.pipelines.ingest import embed_text + + embed_fn = embed_text + + embeddings = await _embed_chunks(chunks, embed_fn, cfg.embed_concurrency) + chunk_count = len(chunks) + ids = [ + f"original:{original_doc_id}:chunk:{idx}" + for idx in range(chunk_count) + ] + metadatas = [ + { + "user_id": user_id, + "domain": ORIGINAL_CHUNK_DOMAIN, + "original_doc_id": original_doc_id, + "s3_key": write_result["key"], + "bucket": write_result["bucket"], + "chunk_index": idx, + "chunk_count": chunk_count, + "content_sha256": content_sha256, + "source_type": str(payload.get("source_type") or "conversation"), + "created_at": created_at, + } + for idx in range(chunk_count) + ] + await _index_chunks( + vector_store=vector_store, + chunks=chunks, + embeddings=embeddings, + metadatas=metadatas, + ids=ids, + batch_size=cfg.index_batch_size, + concurrency=cfg.index_concurrency, + ) + + logger.info( + "Preserved original_doc_id=%s chunks=%d s3_key=%s", + original_doc_id, + chunk_count, + write_result["key"], + ) + return { + "status": "stored", + "original_doc_id": original_doc_id, + "bucket": write_result["bucket"], + "s3_key": write_result["key"], + "etag": write_result.get("etag"), + "indexed_chunks": chunk_count, + "content_sha256": content_sha256, + } diff --git a/tests/api/test_v2_hybrid_search.py b/tests/api/test_v2_hybrid_search.py new file mode 100644 index 00000000..9d45e048 --- /dev/null +++ b/tests/api/test_v2_hybrid_search.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api import dependencies as deps +from src.api.routes.v2 import memory as memory_v2 +from src.storage.original import ORIGINAL_CHUNK_DOMAIN + + +class FakeIngestPipeline: + model = SimpleNamespace(model="fake-ingest") + + +class FakeRetrievalPipeline: + model = SimpleNamespace(model="fake-retrieval") + + def __init__(self, vector_store): + self.vector_store = vector_store + self.neo4j = SimpleNamespace() + + +def test_v2_hybrid_search_returns_memory_and_original_chunks( + monkeypatch, + vector_store, +): + vector_store.seed( + "summary-1", + "Extracted memory about deterministic retries.", + {"user_id": "hunter", "domain": "summary"}, + score=0.91, + ) + vector_store.seed( + "original-1", + "Raw source chunk mentioning S3 and original preservation.", + { + "user_id": "hunter", + "domain": ORIGINAL_CHUNK_DOMAIN, + "original_doc_id": "doc-1", + "s3_key": "originals/test/user/doc-1.json", + }, + score=0.83, + ) + vector_store.seed( + "other-user", + "Wrong user source chunk.", + {"user_id": "someone-else", "domain": ORIGINAL_CHUNK_DOMAIN}, + score=0.99, + ) + + deps._init_error = None + deps._pipelines_ready.set() + deps.set_pipelines(FakeIngestPipeline(), FakeRetrievalPipeline(vector_store)) + monkeypatch.setattr(memory_v2.settings, "original_storage_enabled", True, raising=False) + + async def fake_user(): + return {"id": "user-1", "username": "hunter"} + + async def fake_ready(): + return None + + async def fake_rate_limit(): + return None + + app = FastAPI() + app.dependency_overrides[deps.require_api_key] = fake_user + app.dependency_overrides[deps.require_ready] = fake_ready + app.dependency_overrides[deps.enforce_rate_limit] = fake_rate_limit + app.include_router(memory_v2.router) + + response = TestClient(app).post( + "/v2/memory/hybrid-search", + json={ + "query": "where is the original S3 preservation plan?", + "user_id": "hunter", + "domains": ["summary"], + "memory_top_k": 5, + "original_top_k": 5, + }, + ) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["original_storage_enabled"] is True + assert [r["domain"] for r in data["memory_results"]] == ["summary"] + assert [r["domain"] for r in data["original_chunks"]] == [ORIGINAL_CHUNK_DOMAIN] + assert data["total"] == 2 + assert data["original_chunks"][0]["metadata"]["original_doc_id"] == "doc-1" diff --git a/tests/unit/test_original_storage.py b/tests/unit/test_original_storage.py new file mode 100644 index 00000000..f00021bc --- /dev/null +++ b/tests/unit/test_original_storage.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import json + +import pytest + +from src.storage import original + + +class FakeS3OriginalStore: + writes = [] + + def __init__(self, cfg): + self.cfg = cfg + + def put_json(self, key, body): + self.__class__.writes.append((key, body)) + return {"bucket": self.cfg.bucket, "key": key, "etag": '"fake-etag"'} + + +@pytest.mark.asyncio +async def test_preserve_original_stores_s3_json_and_indexes_deterministic_chunks( + monkeypatch, + vector_store, +): + FakeS3OriginalStore.writes = [] + monkeypatch.setattr(original, "S3OriginalStore", FakeS3OriginalStore) + monkeypatch.setattr(original.settings, "environment", "test", raising=False) + + payload = { + "user_id": "alice@example.com", + "user_query": "Please remember the migration plan. " * 80, + "agent_response": "Use deterministic ids and retries. " * 80, + "session_datetime": "2026-06-03T10:00:00Z", + "job_id": "memory_ingest:test", + "original_config": { + "enabled": True, + "provider": "s3", + "bucket": "xmem-originals-test", + "region": "us-east-1", + "prefix": "originals", + "chunk_size_tokens": 80, + "chunk_overlap_tokens": 10, + "index_batch_size": 2, + "embed_concurrency": 2, + "index_concurrency": 1, + "max_bytes": 10_000_000, + "include_agent_response": True, + "include_image_url": False, + }, + } + + result = await original.preserve_original( + payload, + vector_store=vector_store, + embed_fn=lambda _text: [0.0, 0.0, 0.0], + ) + second = await original.preserve_original( + payload, + vector_store=vector_store, + embed_fn=lambda _text: [0.0, 0.0, 0.0], + ) + + assert result["status"] == "stored" + assert result["original_doc_id"] == second["original_doc_id"] + assert result["indexed_chunks"] > 1 + assert "alice@example.com" not in result["s3_key"] + assert all( + record_id.startswith(f"original:{result['original_doc_id']}:chunk:") + for record_id in vector_store.records + ) + assert all( + record["metadata"]["domain"] == original.ORIGINAL_CHUNK_DOMAIN + for record in vector_store.records.values() + ) + + key, body = FakeS3OriginalStore.writes[0] + stored = json.loads(body) + assert key == result["s3_key"] + assert stored["original_doc_id"] == result["original_doc_id"] + assert stored["user_id_hash"] + assert "user_id" not in stored + + +@pytest.mark.asyncio +async def test_preserve_original_disabled_is_noop(vector_store): + result = await original.preserve_original( + { + "user_id": "alice", + "user_query": "remember this", + "original_config": {"enabled": False}, + }, + vector_store=vector_store, + embed_fn=lambda _text: [0.0, 0.0, 0.0], + ) + + assert result == {"status": "disabled", "indexed_chunks": 0} + assert vector_store.records == {} From f793da48af81b7c59c187e24e4d43de7d43b169a Mon Sep 17 00:00:00 2001 From: ved015 Date: Wed, 3 Jun 2026 20:48:23 +0530 Subject: [PATCH 2/4] Address review feedback for v2 original storage --- CHANGELOG.md | 1 + src/api/routes/v2/memory.py | 21 +++++++++++++-------- src/api/routes/v2/workflows.py | 30 +++++++++++++++++++----------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e069a92..da896677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Add v2-only original document preservation in S3 with indexed original chunks and hybrid memory search. - Add modular Razorpay billing, credit wallets, ledger reservations, and v2 memory workflow metering. - Add durable Temporal-backed v2 memory and scanner workflow APIs with job status, retry, cancel, and dead-letter endpoints. - Add modular LoCoMo and BEAM benchmark runners for the Python XMem API. diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 2b559d05..99b33dfb 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -355,16 +355,21 @@ async def hybrid_search_memory_v2( try: memory_results: list[SourceRecord] = [] if "profile" in req.domains: - memory_results.extend(memory_v1._search_profile(pipeline, user_id)) + profile_results = await asyncio.to_thread( + memory_v1._search_profile, + pipeline, + user_id, + ) + memory_results.extend(profile_results) if "temporal" in req.domains: - memory_results.extend( - memory_v1._search_temporal( - pipeline, - req.query, - user_id, - memory_top_k, - ) + temporal_results = await asyncio.to_thread( + memory_v1._search_temporal, + pipeline, + req.query, + user_id, + memory_top_k, ) + memory_results.extend(temporal_results) if "summary" in req.domains: memory_results.extend( await memory_v1._search_summary( diff --git a/src/api/routes/v2/workflows.py b/src/api/routes/v2/workflows.py index 62d5f987..e2a67f1c 100644 --- a/src/api/routes/v2/workflows.py +++ b/src/api/routes/v2/workflows.py @@ -303,17 +303,25 @@ async def _run_item(index: int, item: Dict[str, Any]): item_payload[key] = payload[key] original_task = _start_original_task(job_id, item_payload) - extraction_task = asyncio.create_task(_execute( - "memory_run_pipeline_activity", - {**item_payload, **billing_activity}, - item_timeout, - )) - item_result = await extraction_task - item_result["original_storage"] = await _await_original_task( - original_task, - item_payload, - ) - return index, item_result + try: + item_result = await _execute( + "memory_run_pipeline_activity", + {**item_payload, **billing_activity}, + item_timeout, + ) + item_result["original_storage"] = await _await_original_task( + original_task, + item_payload, + ) + original_task = None + return index, item_result + finally: + if original_task and not original_task.done(): + original_task.cancel() + try: + await original_task + except BaseException: + pass for start in range(0, len(items), concurrency): window = [ From c2ee6625a2b6fe63f3bab981ef56714cc029d95d Mon Sep 17 00:00:00 2001 From: ved015 Date: Wed, 3 Jun 2026 20:51:09 +0530 Subject: [PATCH 3/4] Narrow cancelled task cleanup handling --- src/api/routes/v2/workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/routes/v2/workflows.py b/src/api/routes/v2/workflows.py index e2a67f1c..17f6265b 100644 --- a/src/api/routes/v2/workflows.py +++ b/src/api/routes/v2/workflows.py @@ -320,7 +320,7 @@ async def _run_item(index: int, item: Dict[str, Any]): original_task.cancel() try: await original_task - except BaseException: + except (asyncio.CancelledError, CancelledError): pass for start in range(0, len(items), concurrency): From 0b79b8d261f5373076dc616734b4466add3c5197 Mon Sep 17 00:00:00 2001 From: ved015 Date: Wed, 3 Jun 2026 21:50:00 +0530 Subject: [PATCH 4/4] Address Greptile review feedback --- src/api/routes/v2/memory.py | 3 ++- src/storage/original.py | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 99b33dfb..c4931ce6 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -310,11 +310,11 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user async def _search_original_chunks( + pipeline, query: str, user_id: str, top_k: int, ) -> list[SourceRecord]: - pipeline = get_retrieval_pipeline() raw = await pipeline.vector_store.search_by_text( query_text=query, top_k=top_k, @@ -383,6 +383,7 @@ async def hybrid_search_memory_v2( original_chunks: list[SourceRecord] = [] if req.include_original_chunks and settings.original_storage_enabled: original_chunks = await _search_original_chunks( + pipeline, req.query, user_id, original_top_k, diff --git a/src/storage/original.py b/src/storage/original.py index bd43e3c5..0c795b47 100644 --- a/src/storage/original.py +++ b/src/storage/original.py @@ -177,10 +177,6 @@ def _sha256_text(text: str) -> str: return hashlib.sha256(text.encode("utf-8")).hexdigest() -def _sha256_bytes(data: bytes) -> str: - return hashlib.sha256(data).hexdigest() - - def _utc_now_iso() -> str: return datetime.now(timezone.utc).isoformat()