diff --git a/openviking/server/routers/tasks.py b/openviking/server/routers/tasks.py index a7247b604..f712bd8a9 100644 --- a/openviking/server/routers/tasks.py +++ b/openviking/server/routers/tasks.py @@ -29,8 +29,8 @@ async def get_task( tracker = get_task_tracker() task = tracker.get( task_id, - owner_account_id=_ctx.account_id, - owner_user_id=_ctx.user.user_id, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, ) if not task: raise OpenVikingError( @@ -58,7 +58,7 @@ async def list_tasks( status=status, resource_id=resource_id, limit=limit, - owner_account_id=_ctx.account_id, - owner_user_id=_ctx.user.user_id, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, ) return Response(status="ok", result=[t.to_dict() for t in tasks]) diff --git a/openviking/service/core.py b/openviking/service/core.py index 89c2634cb..0caab80e6 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -21,6 +21,7 @@ from openviking.service.resource_service import ResourceService from openviking.service.search_service import SearchService from openviking.service.session_service import SessionService +from openviking.service.task_tracker import set_task_tracker from openviking.session import SessionCompressor, create_session_compressor from openviking.storage import VikingDBManager from openviking.storage.collection_schemas import init_context_collection @@ -149,6 +150,7 @@ def _init_storage( lock_expire=tx_cfg.lock_expire, redo_recovery_enabled=tx_cfg.redo_recovery_enabled, ) + set_task_tracker(config.build_task_tracker(self._agfs_client)) @property def _agfs(self) -> Any: diff --git a/openviking/service/reindex_executor.py b/openviking/service/reindex_executor.py index c4b84dad7..334414f7b 100644 --- a/openviking/service/reindex_executor.py +++ b/openviking/service/reindex_executor.py @@ -89,8 +89,8 @@ async def execute( if tracker.has_running( REINDEX_TASK_TYPE, uri, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + 
account_id=ctx.account_id, + user_id=ctx.user.user_id, ): raise OpenVikingError( f"URI {uri} already has a reindex in progress", @@ -107,8 +107,8 @@ async def execute( task = tracker.create_if_no_running( REINDEX_TASK_TYPE, uri, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) if task is None: raise OpenVikingError( @@ -376,7 +376,7 @@ async def _run_tracked( ctx: RequestContext, ) -> None: tracker = get_task_tracker() - tracker.start(task_id) + tracker.start(task_id, account_id=ctx.account_id, user_id=ctx.user.user_id) try: result = await self._run( uri=uri, @@ -384,9 +384,9 @@ mode=mode, ctx=ctx, ) - tracker.complete(task_id, result) + tracker.complete(task_id, result, account_id=ctx.account_id, user_id=ctx.user.user_id) except Exception as exc: - tracker.fail(task_id, str(exc)) + tracker.fail(task_id, str(exc), account_id=ctx.account_id, user_id=ctx.user.user_id) async def _reindex_resource( self, diff --git a/openviking/service/resource_service.py b/openviking/service/resource_service.py index 72aa043a7..a38afb570 100644 --- a/openviking/service/resource_service.py +++ b/openviking/service/resource_service.py @@ -295,16 +295,30 @@ async def add_resource( task = task_tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) result["task_id"] = task.task_id if telemetry_id: monitor_started = True - asyncio.create_task(self._monitor_queue_processing(task.task_id, telemetry_id)) + asyncio.create_task( + self._monitor_queue_processing( + task.task_id, + telemetry_id, + ctx.account_id, + ctx.user.user_id, + ) + ) else: - task_tracker.start(task.task_id) - task_tracker.complete(task.task_id, {"root_uri": root_uri}) + task_tracker.start( + task.task_id, account_id=ctx.account_id, user_id=ctx.user.user_id + ) + task_tracker.complete( + task.task_id, + {"root_uri": 
root_uri}, + account_id=ctx.account_id, + user_id=ctx.user.user_id, + ) return result except Exception as exc: telemetry.set_error( @@ -322,22 +335,38 @@ async def add_resource( get_request_wait_tracker().cleanup(telemetry_id) unregister_wait_telemetry(telemetry_id) - async def _monitor_queue_processing(self, task_id: str, telemetry_id: str) -> None: + async def _monitor_queue_processing( + self, + task_id: str, + telemetry_id: str, + account_id: str, + user_id: str, + ) -> None: from openviking.service.task_tracker import get_task_tracker task_tracker = get_task_tracker() request_wait_tracker = get_request_wait_tracker() - task_tracker.start(task_id) + task_tracker.start(task_id, account_id=account_id, user_id=user_id) try: await request_wait_tracker.wait_for_request(telemetry_id) status = request_wait_tracker.build_queue_status(telemetry_id) errors = sum(int(group.get("error_count", 0) or 0) for group in status.values()) if errors: - task_tracker.fail(task_id, f"queue processing failed: {status}") + task_tracker.fail( + task_id, + f"queue processing failed: {status}", + account_id=account_id, + user_id=user_id, + ) else: - task_tracker.complete(task_id, {"queue_status": status}) + task_tracker.complete( + task_id, + {"queue_status": status}, + account_id=account_id, + user_id=user_id, + ) except Exception as exc: - task_tracker.fail(task_id, str(exc)) + task_tracker.fail(task_id, str(exc), account_id=account_id, user_id=user_id) finally: request_wait_tracker.cleanup(telemetry_id) unregister_wait_telemetry(telemetry_id) @@ -518,16 +547,30 @@ async def add_skill( task_tracker = get_task_tracker() task = task_tracker.create( "add_skill", - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) result["task_id"] = task.task_id if telemetry_id: monitor_started = True - asyncio.create_task(self._monitor_queue_processing(task.task_id, telemetry_id)) + asyncio.create_task( + 
self._monitor_queue_processing( + task.task_id, + telemetry_id, + ctx.account_id, + ctx.user.user_id, + ) + ) else: - task_tracker.start(task.task_id) - task_tracker.complete(task.task_id, {}) + task_tracker.start( + task.task_id, account_id=ctx.account_id, user_id=ctx.user.user_id + ) + task_tracker.complete( + task.task_id, + {}, + account_id=ctx.account_id, + user_id=ctx.user.user_id, + ) return result finally: diff --git a/openviking/service/session_service.py b/openviking/service/session_service.py index 7fb85d116..e0b9ece58 100644 --- a/openviking/service/session_service.py +++ b/openviking/service/session_service.py @@ -252,7 +252,8 @@ async def get_commit_task(self, task_id: str, ctx: RequestContext) -> Optional[D """Query background commit task status by task_id for the calling owner.""" task = get_task_tracker().get( task_id, - owner_account_id=ctx.account_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) return task.to_dict() if task else None diff --git a/openviking/service/task_store.py b/openviking/service/task_store.py new file mode 100644 index 000000000..770bbd026 --- /dev/null +++ b/openviking/service/task_store.py @@ -0,0 +1,192 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 +"""Internal storage backends for TaskTracker.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, Dict, List, Optional, Protocol + +from openviking.pyagfs.exceptions import AGFSAlreadyExistsError + + +class TaskStore(Protocol): + def create(self, task: Any) -> None: ... + + def update(self, task: Any) -> None: ... + + def get( + self, + task_id: str, + *, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: ... + + def list(self, account_id: str, *, user_id: Optional[str] = None) -> List[Dict[str, Any]]: ... 
+ + def delete(self, task_id: str, *, account_id: str, user_id: Optional[str] = None) -> None: ... + + +class InMemoryTaskStore: + """Simple in-process task store.""" + + def __init__(self) -> None: + self._tasks: Dict[str, Dict[str, Any]] = {} + + def create(self, task: Any) -> None: + self._tasks[task.task_id] = _task_to_payload(task) + + def update(self, task: Any) -> None: + self._tasks[task.task_id] = _task_to_payload(task) + + def get( + self, + task_id: str, + *, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + payload = self._tasks.get(task_id) + if payload is None: + return None + if account_id is not None and payload.get("account_id") != account_id: + return None + if user_id is not None and payload.get("user_id") != user_id: + return None + return deepcopy(payload) + + def list(self, account_id: str, *, user_id: Optional[str] = None) -> List[Dict[str, Any]]: + return [ + deepcopy(payload) + for payload in self._tasks.values() + if payload.get("account_id") == account_id + and (user_id is None or payload.get("user_id") == user_id) + ] + + def delete(self, task_id: str, *, account_id: str, user_id: Optional[str] = None) -> None: + payload = self._tasks.get(task_id) + if ( + payload + and payload.get("account_id") == account_id + and (user_id is None or payload.get("user_id") == user_id) + ): + del self._tasks[task_id] + + +class PersistentTaskStore: + """Persist task records into AGFS under account-scoped task directories.""" + + ROOT_PREFIX = "/local" + RESERVED_DIRNAME = "tasks" + + def __init__(self, agfs: Any) -> None: + self._agfs = agfs + + def create(self, task: Any) -> None: + self._write_task(task) + + def update(self, task: Any) -> None: + self._write_task(task) + + def get( + self, + task_id: str, + *, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + if not account_id or not user_id: + return None + path = self._task_path(account_id, 
user_id, task_id) + try: + raw = self._agfs.read(path) + except Exception: + return None + return json.loads(_decode_bytes(raw)) + + def list(self, account_id: str, *, user_id: Optional[str] = None) -> List[Dict[str, Any]]: + if not user_id: + return [] + directory = self._task_dir(account_id, user_id) + try: + items = self._agfs.ls(directory) + except Exception: + return [] + tasks: List[Dict[str, Any]] = [] + for item in items: + path = item.get("path") or f"{directory}/{item.get('name', '')}" + if not path.endswith(".json"): + continue + try: + raw = self._agfs.read(path) + tasks.append(json.loads(_decode_bytes(raw))) + except Exception: + continue + return tasks + + def delete(self, task_id: str, *, account_id: str, user_id: Optional[str] = None) -> None: + if not user_id: + return + self._agfs.rm(self._task_path(account_id, user_id, task_id), force=True) + + def _write_task(self, task: Any) -> None: + account_id = getattr(task, "account_id", None) + user_id = getattr(task, "user_id", None) + if not account_id or not user_id: + raise ValueError("PersistentTaskStore requires account_id and user_id") + self._ensure_task_dir(account_id, user_id) + self._agfs.write( + self._task_path(account_id, user_id, task.task_id), + json.dumps(_task_to_payload(task), ensure_ascii=False).encode("utf-8"), + ) + + def _ensure_task_dir(self, account_id: str, user_id: str) -> None: + self._mkdir_if_missing(self._account_dir(account_id)) + self._mkdir_if_missing(self._task_root_dir(account_id)) + self._mkdir_if_missing(self._task_dir(account_id, user_id)) + + def _mkdir_if_missing(self, path: str) -> None: + try: + self._agfs.mkdir(path) + except AGFSAlreadyExistsError: + return + except Exception as exc: + if "already exists" in str(exc).lower(): + return + raise + + def _account_dir(self, account_id: str) -> str: + return f"{self.ROOT_PREFIX}/{account_id}" + + def _task_root_dir(self, account_id: str) -> str: + return f"{self._account_dir(account_id)}/{self.RESERVED_DIRNAME}" + + 
def _task_dir(self, account_id: str, user_id: str) -> str: + return f"{self._task_root_dir(account_id)}/{user_id}" + + def _task_path(self, account_id: str, user_id: str, task_id: str) -> str: + return f"{self._task_dir(account_id, user_id)}/{task_id}.json" + + +def _task_to_payload(task: Any) -> Dict[str, Any]: + status = getattr(task, "status", None) + return { + "task_id": task.task_id, + "task_type": task.task_type, + "status": status.value if hasattr(status, "value") else status, + "created_at": task.created_at, + "updated_at": task.updated_at, + "resource_id": task.resource_id, + "account_id": task.account_id, + "user_id": task.user_id, + "result": deepcopy(task.result), + "error": task.error, + } + + +def _decode_bytes(raw: Any) -> str: + if isinstance(raw, bytes): + return raw.decode("utf-8") + return str(raw) diff --git a/openviking/service/task_tracker.py b/openviking/service/task_tracker.py index 27bdad1e7..ff65363c5 100644 --- a/openviking/service/task_tracker.py +++ b/openviking/service/task_tracker.py @@ -3,14 +3,13 @@ """ Async Task Tracker for OpenViking. -Provides a lightweight, in-memory registry for tracking background operations +Provides a lightweight registry for tracking background operations (e.g. session commit with wait=false). Callers receive a task_id that can be polled via the /tasks API to check completion status, results, or errors. Design decisions: - - v1 is pure in-memory (no persistence). Tasks are lost on restart. - Thread-safe (QueueManager workers run in separate threads). - - TTL-based cleanup prevents unbounded memory growth. + - TTL-based cleanup still applies to in-memory cache. - Error messages are sanitized to avoid leaking sensitive data. 
""" @@ -25,6 +24,7 @@ from typing import Any, Dict, List, Optional from uuid import uuid4 +from openviking.service.task_store import InMemoryTaskStore, TaskStore from openviking_cli.utils.logger import get_logger logger = get_logger(__name__) @@ -49,8 +49,8 @@ class TaskRecord: created_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) resource_id: Optional[str] = None # e.g. session_id - owner_account_id: Optional[str] = None - owner_user_id: Optional[str] = None + account_id: Optional[str] = None + user_id: Optional[str] = None result: Optional[Dict[str, Any]] = None error: Optional[str] = None @@ -60,8 +60,8 @@ def to_dict(self) -> Dict[str, Any]: d["status"] = self.status.value d["created_at_iso"] = datetime.fromtimestamp(self.created_at, tz=timezone.utc).isoformat() d["updated_at_iso"] = datetime.fromtimestamp(self.updated_at, tz=timezone.utc).isoformat() - d.pop("owner_account_id", None) - d.pop("owner_user_id", None) + d.pop("account_id", None) + d.pop("user_id", None) return d @@ -81,6 +81,13 @@ def get_task_tracker() -> "TaskTracker": return _instance +def set_task_tracker(tracker: "TaskTracker") -> None: + """Replace the global TaskTracker singleton.""" + global _instance + with _init_lock: + _instance = tracker + + def reset_task_tracker() -> None: """Reset singleton (for testing).""" global _instance @@ -109,7 +116,7 @@ def _sanitize_error(error: str) -> str: class TaskTracker: - """In-memory async task tracker with TTL-based cleanup. + """Async task tracker with pluggable storage and in-memory compatibility cache. Thread-safe: all mutations go through ``_lock``. 
""" @@ -119,11 +126,16 @@ class TaskTracker: TTL_FAILED = 604_800 # 7 days CLEANUP_INTERVAL = 300 # 5 minutes - def __init__(self) -> None: + def __init__(self, store: Optional[TaskStore] = None) -> None: + self._store = store or InMemoryTaskStore() self._tasks: Dict[str, TaskRecord] = {} self._lock = threading.Lock() self._cleanup_task: Optional[asyncio.Task] = None - logger.info("[TaskTracker] Initialized (in-memory, max_tasks=%d)", self.MAX_TASKS) + logger.info( + "[TaskTracker] Initialized (store=%s, max_tasks=%d)", + self._store.__class__.__name__, + self.MAX_TASKS, + ) # ── Lifecycle ── @@ -165,6 +177,9 @@ def _evict_expired(self) -> None: elif t.status == TaskStatus.FAILED and (now - t.updated_at) > self.TTL_FAILED: to_delete.append(tid) for tid in to_delete: + task = self._tasks[tid] + if isinstance(self._store, InMemoryTaskStore) and task.account_id: + self._store.delete(tid, account_id=task.account_id, user_id=task.user_id) del self._tasks[tid] # FIFO eviction if still over limit @@ -172,6 +187,9 @@ def _evict_expired(self) -> None: sorted_tasks = sorted(self._tasks.items(), key=lambda x: x[1].created_at) excess = len(self._tasks) - self.MAX_TASKS for tid, _ in sorted_tasks[:excess]: + task = self._tasks[tid] + if isinstance(self._store, InMemoryTaskStore) and task.account_id: + self._store.delete(tid, account_id=task.account_id, user_id=task.user_id) del self._tasks[tid] if to_delete: @@ -180,21 +198,21 @@ def _evict_expired(self) -> None: @staticmethod def _matches_owner( task: TaskRecord, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> bool: """Return True when a task belongs to the requested owner filter.""" - if owner_account_id is not None and task.owner_account_id != owner_account_id: + if account_id is not None and task.account_id != account_id: return False - if owner_user_id is not None and task.owner_user_id != owner_user_id: + if user_id 
is not None and task.user_id != user_id: return False return True @staticmethod - def _validate_owner(owner_account_id: str, owner_user_id: str) -> None: + def _validate_owner(account_id: str, user_id: str) -> None: """Reject ownerless task creation for user-originated background work.""" - if not owner_account_id or not owner_user_id: - raise ValueError("Task ownership requires non-empty owner_account_id and owner_user_id") + if not account_id or not user_id: + raise ValueError("Task ownership requires non-empty account_id and user_id") # ── CRUD ── @@ -203,20 +221,21 @@ def create( task_type: str, resource_id: Optional[str] = None, *, - owner_account_id: str, - owner_user_id: str, + account_id: str, + user_id: str, ) -> TaskRecord: """Register a new pending task. Returns a snapshot copy.""" - self._validate_owner(owner_account_id, owner_user_id) + self._validate_owner(account_id, user_id) task = TaskRecord( task_id=str(uuid4()), task_type=task_type, resource_id=resource_id, - owner_account_id=owner_account_id, - owner_user_id=owner_user_id, + account_id=account_id, + user_id=user_id, ) with self._lock: self._tasks[task.task_id] = task + self._store.create(task) logger.debug( "[TaskTracker] Created task %s type=%s resource=%s", task.task_id, @@ -230,21 +249,21 @@ def create_if_no_running( task_type: str, resource_id: str, *, - owner_account_id: str, - owner_user_id: str, + account_id: str, + user_id: str, ) -> Optional[TaskRecord]: """Atomically check for running tasks and create a new one if none exist. Returns TaskRecord on success, None if a running task already exists. This eliminates the race condition between has_running() and create(). 
""" - self._validate_owner(owner_account_id, owner_user_id) + self._validate_owner(account_id, user_id) with self._lock: # Check for existing running tasks has_active = any( t.task_type == task_type and t.resource_id == resource_id - and self._matches_owner(t, owner_account_id, owner_user_id) + and self._matches_owner(t, account_id, user_id) and t.status in (TaskStatus.PENDING, TaskStatus.RUNNING) for t in self._tasks.values() ) @@ -255,10 +274,11 @@ def create_if_no_running( task_id=str(uuid4()), task_type=task_type, resource_id=resource_id, - owner_account_id=owner_account_id, - owner_user_id=owner_user_id, + account_id=account_id, + user_id=user_id, ) self._tasks[task.task_id] = task + self._store.create(task) logger.debug( "[TaskTracker] Created task %s type=%s resource=%s", task.task_id, @@ -267,44 +287,71 @@ def create_if_no_running( ) return self._copy(task) - def start(self, task_id: str) -> None: + def start( + self, + task_id: str, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> None: """Transition task to RUNNING.""" with self._lock: - task = self._tasks.get(task_id) + task = self._load_for_update(task_id, account_id, user_id) if task: task.status = TaskStatus.RUNNING task.updated_at = time.time() + self._tasks[task.task_id] = task + self._store.update(task) - def complete(self, task_id: str, result: Optional[Dict[str, Any]] = None) -> None: + def complete( + self, + task_id: str, + result: Optional[Dict[str, Any]] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> None: """Transition task to COMPLETED with optional result.""" with self._lock: - task = self._tasks.get(task_id) + task = self._load_for_update(task_id, account_id, user_id) if task: task.status = TaskStatus.COMPLETED task.result = result task.updated_at = time.time() + self._tasks[task.task_id] = task + self._store.update(task) logger.info("[TaskTracker] Task %s completed", task_id) - def fail(self, task_id: str, error: str) -> 
None: + def fail( + self, + task_id: str, + error: str, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> None: """Transition task to FAILED with sanitized error.""" with self._lock: - task = self._tasks.get(task_id) + task = self._load_for_update(task_id, account_id, user_id) if task: task.status = TaskStatus.FAILED task.error = _sanitize_error(error) task.updated_at = time.time() + self._tasks[task.task_id] = task + self._store.update(task) logger.warning("[TaskTracker] Task %s failed: %s", task_id, _sanitize_error(error)) def get( self, task_id: str, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> Optional[TaskRecord]: """Look up a single task. Returns a snapshot copy (None if not found).""" with self._lock: task = self._tasks.get(task_id) - if task is None or not self._matches_owner(task, owner_account_id, owner_user_id): + if task is None and account_id is not None: + task = self._load_from_store(task_id, account_id, user_id) + if task is not None: + self._tasks[task.task_id] = task + if task is None or not self._matches_owner(task, account_id, user_id): return None return self._copy(task) @@ -314,15 +361,18 @@ def list_tasks( status: Optional[str] = None, resource_id: Optional[str] = None, limit: int = 50, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> List[TaskRecord]: """List tasks with optional filters. Most-recent first. 
Returns snapshot copies.""" with self._lock: + if account_id is not None: + for task in self._load_all_from_store(account_id, user_id): + self._tasks[task.task_id] = task tasks = [ self._copy(t) for t in self._tasks.values() - if self._matches_owner(t, owner_account_id, owner_user_id) + if self._matches_owner(t, account_id, user_id) ] if task_type: tasks = [t for t in tasks if t.task_type == task_type] @@ -337,19 +387,58 @@ def has_running( self, task_type: str, resource_id: str, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> bool: """Check if there is already a running task for the given type+resource.""" with self._lock: + if account_id is not None: + for task in self._load_all_from_store(account_id, user_id): + self._tasks[task.task_id] = task return any( t.task_type == task_type and t.resource_id == resource_id - and self._matches_owner(t, owner_account_id, owner_user_id) + and self._matches_owner(t, account_id, user_id) and t.status in (TaskStatus.PENDING, TaskStatus.RUNNING) for t in self._tasks.values() ) + def _load_for_update( + self, + task_id: str, + account_id: Optional[str], + user_id: Optional[str], + ) -> Optional[TaskRecord]: + task = self._tasks.get(task_id) + if task is not None: + return task + if account_id is None or user_id is None: + return None + return self._load_from_store(task_id, account_id, user_id) + + @staticmethod + def _record_from_payload(payload: Dict[str, Any]) -> TaskRecord: + data = dict(payload) + data["status"] = TaskStatus(data["status"]) + return TaskRecord(**data) + + def _load_from_store( + self, + task_id: str, + account_id: str, + user_id: Optional[str], + ) -> Optional[TaskRecord]: + payload = self._store.get(task_id, account_id=account_id, user_id=user_id) + if payload is None: + return None + return self._record_from_payload(payload) + + def _load_all_from_store(self, account_id: str, user_id: Optional[str]) -> 
List[TaskRecord]: + return [ + self._record_from_payload(payload) + for payload in self._store.list(account_id, user_id=user_id) + ] + @staticmethod def _copy(task: TaskRecord) -> TaskRecord: """Return a defensive copy of a TaskRecord.""" diff --git a/openviking/session/session.py b/openviking/session/session.py index e5bacb702..f38cc38e6 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -705,8 +705,8 @@ async def commit_async(self, keep_recent_count: int = 0) -> Dict[str, Any]: task = tracker.create( "session_commit", resource_id=self.session_id, - owner_account_id=self.ctx.account_id, - owner_user_id=self.ctx.user.user_id, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, ) asyncio.create_task( @@ -773,10 +773,12 @@ async def _run_memory_extraction( task_id, f"Previous archive archive_{archive_index - 1:03d} failed; " "cannot continue session commit", + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, ) return - tracker.start(task_id) + tracker.start(task_id, account_id=self.ctx.account_id, user_id=self.ctx.user.user_id) request_wait_tracker.register_request(telemetry.telemetry_id) register_telemetry(telemetry) try: @@ -979,6 +981,8 @@ async def _noop_agent(): }, }, }, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, ) logger.info(f"Session {self.session_id} memory extraction completed") except Exception as e: @@ -989,7 +993,9 @@ async def _noop_agent(): stage="memory_extraction", error=str(e), ) - tracker.fail(task_id, str(e)) + tracker.fail( + task_id, str(e), account_id=self.ctx.account_id, user_id=self.ctx.user.user_id + ) logger.exception(f"Memory extraction failed for session {self.session_id}") async def _write_done_file( diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index c5837e762..9bf623e89 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -1629,7 +1629,7 @@ def _uri_to_path(self, uri: str, ctx: 
Optional[RequestContext] = None) -> str: safe_parts = [self._shorten_component(p, self._MAX_FILENAME_BYTES) for p in parts] return f"/local/{account_id}/{'/'.join(safe_parts)}" - _INTERNAL_NAMES = {"_system", ".path.ovlock"} + _INTERNAL_NAMES = {"_system", "tasks", ".path.ovlock"} _ROOT_PATH = "/local" def _ls_entries(self, path: str) -> List[Dict[str, Any]]: diff --git a/openviking_cli/utils/config/storage_config.py b/openviking_cli/utils/config/storage_config.py index 59aa2430c..1ea24a61a 100644 --- a/openviking_cli/utils/config/storage_config.py +++ b/openviking_cli/utils/config/storage_config.py @@ -1,7 +1,7 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: AGPL-3.0 from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Literal from pydantic import BaseModel, Field, model_validator @@ -14,6 +14,15 @@ logger = get_logger(__name__) +class TaskTrackerConfig(BaseModel): + """Configuration for async task tracking backend.""" + + backend: Literal["memory", "persistent"] = Field( + default="memory", + description="Task tracker backend. 'persistent' enables cross-instance task lookup.", + ) + + class StorageConfig(BaseModel): """Configuration for storage backend. 
@@ -36,6 +45,11 @@ class StorageConfig(BaseModel): description="VectorDB backend configuration", ) + task_tracker: TaskTrackerConfig = Field( + default_factory=TaskTrackerConfig, + description="Task tracker backend configuration", + ) + params: Dict[str, Any] = Field( default_factory=dict, description="Additional storage-specific parameters" ) @@ -75,3 +89,12 @@ def get_upload_temp_dir(self) -> Path: upload_temp_dir = workspace_path / "temp" / "upload" upload_temp_dir.mkdir(parents=True, exist_ok=True) return upload_temp_dir + + def build_task_tracker(self, agfs: Any): + """Build a TaskTracker from storage config.""" + from openviking.service.task_store import PersistentTaskStore + from openviking.service.task_tracker import TaskTracker + + if self.task_tracker.backend == "memory": + return TaskTracker() + return TaskTracker(store=PersistentTaskStore(agfs)) diff --git a/tests/misc/test_vikingfs_uri_guard.py b/tests/misc/test_vikingfs_uri_guard.py index febdba1c2..36382a902 100644 --- a/tests/misc/test_vikingfs_uri_guard.py +++ b/tests/misc/test_vikingfs_uri_guard.py @@ -194,3 +194,15 @@ async def test_grep_with_agfs_maps_dot_match_to_query_root_uri(self) -> None: assert result["matches"][0]["uri"] == "viking://resources/test-root" assert result["matches"][0]["line"] == 1 assert result["matches"][0]["content"] == "act" + + def test_ls_entries_hides_reserved_tasks_dir_under_account_root(self) -> None: + fs = _make_viking_fs() + fs.agfs.ls.return_value = [ + {"name": "resources", "isDir": True}, + {"name": "tasks", "isDir": True}, + {"name": "_system", "isDir": True}, + ] + + entries = fs._ls_entries("/local/default") + + assert [entry["name"] for entry in entries] == ["resources"] diff --git a/tests/server/test_auth.py b/tests/server/test_auth.py index c14337501..b2692041d 100644 --- a/tests/server/test_auth.py +++ b/tests/server/test_auth.py @@ -323,14 +323,14 @@ async def test_task_endpoints_are_user_scoped(): alice_task = tracker.create( "session_commit", 
resource_id="alice-session", - owner_account_id=account_id, - owner_user_id="alice", + account_id=account_id, + user_id="alice", ) bob_task = tracker.create( "session_commit", resource_id="bob-session", - owner_account_id=account_id, - owner_user_id="bob", + account_id=account_id, + user_id="bob", ) alice_app = _build_task_http_test_app( diff --git a/tests/test_session_task_tracking.py b/tests/test_session_task_tracking.py index c8c2d15b3..617580ac7 100644 --- a/tests/test_session_task_tracking.py +++ b/tests/test_session_task_tracking.py @@ -69,13 +69,13 @@ async def mock_commit(_sid, _ctx): task = tracker.create( "session_commit", resource_id=_sid, - owner_account_id=_ctx.account_id, - owner_user_id=_ctx.user.user_id, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, ) archive_uri = f"viking://session/test/{_sid}/history/archive_001" async def _background(): - tracker.start(task.task_id) + tracker.start(task.task_id, account_id=_ctx.account_id, user_id=_ctx.user.user_id) try: if started: started.set() @@ -96,9 +96,19 @@ async def _background(): } if result_overrides: final_result.update(result_overrides) - tracker.complete(task.task_id, final_result) + tracker.complete( + task.task_id, + final_result, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, + ) except Exception as e: - tracker.fail(task.task_id, str(e)) + tracker.fail( + task.task_id, + str(e), + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, + ) asyncio.create_task(_background()) @@ -336,11 +346,20 @@ async def fake_add_resource(**kwargs): task = tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.complete( + task.task_id, + {"root_uri": root_uri}, + 
account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) - tracker.start(task.task_id) - tracker.complete(task.task_id, {"root_uri": root_uri}) return {"status": "success", "root_uri": root_uri, "task_id": task.task_id} service.resources.add_resource = fake_add_resource @@ -389,14 +408,23 @@ async def fake_add_resource(**kwargs): task = tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) async def _background(): - tracker.start(task.task_id) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) await asyncio.sleep(0.05) - tracker.complete(task.task_id, {"root_uri": root_uri}) + tracker.complete( + task.task_id, + {"root_uri": root_uri}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) asyncio.create_task(_background()) return {"status": "success", "root_uri": root_uri, "task_id": task.task_id} @@ -434,11 +462,20 @@ async def fake_add_resource(**kwargs): task = tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.complete( + task.task_id, + {"root_uri": root_uri}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) - tracker.start(task.task_id) - tracker.complete(task.task_id, {"root_uri": root_uri}) return {"status": "success", "root_uri": root_uri, "task_id": task.task_id} service.resources.add_resource = fake_add_resource @@ -469,11 +506,20 @@ async def fake_add_skill(**kwargs): tracker = get_task_tracker() task = tracker.create( "add_skill", - 
owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.complete( + task.task_id, + {}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) - tracker.start(task.task_id) - tracker.complete(task.task_id, {}) return {"status": "success", "task_id": task.task_id} service.resources.add_skill = fake_add_skill diff --git a/tests/test_task_backend_config.py b/tests/test_task_backend_config.py new file mode 100644 index 000000000..10d41b0fe --- /dev/null +++ b/tests/test_task_backend_config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 + +from openviking.service.task_store import InMemoryTaskStore, PersistentTaskStore +from openviking.service.task_tracker import TaskTracker +from openviking_cli.utils.config.storage_config import StorageConfig + + +class _FakeAgfs: + def mkdir(self, path: str, mode: str = "755"): + return {"message": "created", "mode": mode} + + def write(self, path: str, data): + return "OK" + + def read(self, path: str, offset: int = 0, size: int = -1, stream: bool = False): + raise FileNotFoundError(path) + + def ls(self, path: str = "/"): + return [] + + def rm(self, path: str, recursive: bool = False, force: bool = True): + return {"message": "deleted"} + + +def test_storage_config_defaults_task_backend_to_memory(): + config = StorageConfig() + assert config.task_tracker.backend == "memory" + + +def test_storage_config_accepts_memory_task_backend(): + config = StorageConfig(task_tracker={"backend": "memory"}) + assert config.task_tracker.backend == "memory" + + +def test_storage_config_builds_memory_task_tracker(): + tracker = StorageConfig(task_tracker={"backend": "memory"}).build_task_tracker(_FakeAgfs()) + assert 
isinstance(tracker, TaskTracker) + assert isinstance(tracker._store, InMemoryTaskStore) + + +def test_storage_config_builds_persistent_task_tracker(): + tracker = StorageConfig(task_tracker={"backend": "persistent"}).build_task_tracker(_FakeAgfs()) + assert isinstance(tracker, TaskTracker) + assert isinstance(tracker._store, PersistentTaskStore) diff --git a/tests/test_task_tracker.py b/tests/test_task_tracker.py index e3eaf597f..a6b4d7d74 100644 --- a/tests/test_task_tracker.py +++ b/tests/test_task_tracker.py @@ -3,12 +3,15 @@ """Unit tests for TaskTracker.""" +import json import time import pytest +from openviking.pyagfs.exceptions import AGFSAlreadyExistsError from openviking.server.identity import RequestContext, Role from openviking.service.session_service import SessionService +from openviking.service.task_store import InMemoryTaskStore, PersistentTaskStore from openviking.service.task_tracker import ( TaskStatus, TaskTracker, @@ -34,8 +37,8 @@ def tracker() -> TaskTracker: def _owner_kwargs(account_id: str = "acme", user_id: str = "alice"): return { - "owner_account_id": account_id, - "owner_user_id": user_id, + "account_id": account_id, + "user_id": user_id, } @@ -46,6 +49,60 @@ def _make_ctx(account_id: str = "acme", user_id: str = "alice") -> RequestContex ) +class _FakeAgfs: + def __init__(self): + self.files = {} + self.dirs = {"/", "/local"} + + def mkdir(self, path: str, mode: str = "755"): + self.dirs.add(path.rstrip("/") or "/") + return {"message": "created", "mode": mode} + + def write(self, path: str, data): + if isinstance(data, str): + data = data.encode("utf-8") + self.files[path] = data + parent = path.rsplit("/", 1)[0] or "/" + self.dirs.add(parent) + return "OK" + + def read(self, path: str, offset: int = 0, size: int = -1, stream: bool = False): + if path not in self.files: + raise FileNotFoundError(path) + data = self.files[path] + if size >= 0: + return data[offset : offset + size] + return data[offset:] + + def ls(self, path: str = 
"/"): + prefix = path.rstrip("/") or "/" + if prefix not in self.dirs: + return [] + children = {} + for directory in self.dirs: + if directory in {prefix, "/"}: + continue + if directory.startswith(prefix + "/"): + name = directory[len(prefix) + 1 :].split("/", 1)[0] + if name: + children[name] = {"name": name, "path": f"{prefix}/{name}", "is_dir": True} + for file_path in self.files: + if file_path.startswith(prefix + "/"): + name = file_path[len(prefix) + 1 :].split("/", 1)[0] + if name and "/" not in file_path[len(prefix) + 1 :]: + children[name] = {"name": name, "path": f"{prefix}/{name}", "is_dir": False} + return list(children.values()) + + +class _FakeAgfsExistingDir(_FakeAgfs): + def mkdir(self, path: str, mode: str = "755"): + normalized = path.rstrip("/") or "/" + if normalized in self.dirs: + raise AGFSAlreadyExistsError(f"already exists: {path}") + self.dirs.add(normalized) + return {"message": "created", "mode": mode} + + # ── Basic CRUD ── @@ -131,15 +188,15 @@ def test_get_hides_task_from_other_owner(tracker: TaskTracker): task = tracker.create( "session_commit", resource_id="s1", - owner_account_id="acme", - owner_user_id="alice", + account_id="acme", + user_id="alice", ) assert ( tracker.get( task.task_id, - owner_account_id="acme", - owner_user_id="bob", + account_id="acme", + user_id="bob", ) is None ) @@ -149,17 +206,17 @@ def test_list_tasks_filters_by_owner(tracker: TaskTracker): tracker.create( "session_commit", resource_id="alice-task", - owner_account_id="acme", - owner_user_id="alice", + account_id="acme", + user_id="alice", ) tracker.create( "session_commit", resource_id="bob-task", - owner_account_id="acme", - owner_user_id="bob", + account_id="acme", + user_id="bob", ) - tasks = tracker.list_tasks(owner_account_id="acme", owner_user_id="alice") + tasks = tracker.list_tasks(account_id="acme", user_id="alice") assert len(tasks) == 1 assert tasks[0].resource_id == "alice-task" @@ -212,14 +269,14 @@ def 
test_create_if_no_running_isolated_by_owner(tracker: TaskTracker): alice_task = tracker.create_if_no_running( "reindex", "viking://resources/demo", - owner_account_id="acme", - owner_user_id="alice", + account_id="acme", + user_id="alice", ) bob_task = tracker.create_if_no_running( "reindex", "viking://resources/demo", - owner_account_id="acme", - owner_user_id="bob", + account_id="acme", + user_id="bob", ) assert alice_task is not None @@ -246,8 +303,8 @@ def test_to_dict(tracker: TaskTracker): assert isinstance(d["created_at_iso"], str) assert "T" in d["created_at_iso"] assert isinstance(d["updated_at_iso"], str) - assert "owner_account_id" not in d - assert "owner_user_id" not in d + assert "account_id" not in d + assert "user_id" not in d # ── Sanitization ── @@ -328,6 +385,71 @@ def test_singleton_reset(): assert t1 is not t2 +def test_persistent_store_cross_tracker_visibility(): + agfs = _FakeAgfs() + store = PersistentTaskStore(agfs) + tracker1 = TaskTracker(store=store) + tracker2 = TaskTracker(store=store) + + task = tracker1.create("session_commit", resource_id="sess-123", **_owner_kwargs()) + tracker1.start(task.task_id, account_id="acme", user_id="alice") + tracker1.complete(task.task_id, {"ok": True}, account_id="acme", user_id="alice") + + loaded = tracker2.get(task.task_id, account_id="acme", user_id="alice") + + assert loaded is not None + assert loaded.status == TaskStatus.COMPLETED + assert loaded.result == {"ok": True} + + +def test_persistent_store_writes_task_record_json(): + agfs = _FakeAgfs() + store = PersistentTaskStore(agfs) + tracker = TaskTracker(store=store) + + task = tracker.create("add_resource", resource_id="viking://resources/demo", **_owner_kwargs()) + + raw = agfs.files[f"/local/acme/tasks/alice/{task.task_id}.json"] + payload = json.loads(raw.decode("utf-8")) + + assert payload["task_id"] == task.task_id + assert payload["task_type"] == "add_resource" + assert payload["account_id"] == "acme" + assert payload["user_id"] == 
"alice" + assert "schema_version" not in payload + + +def test_inmemory_store_keeps_tasktracker_tasks_dict(): + tracker = TaskTracker(store=InMemoryTaskStore()) + task = tracker.create("session_commit", **_owner_kwargs()) + assert task.task_id in tracker._tasks + + +def test_persistent_store_survives_tracker_reset(): + agfs = _FakeAgfs() + tracker1 = TaskTracker(store=PersistentTaskStore(agfs)) + task = tracker1.create("session_commit", resource_id="sess-123", **_owner_kwargs()) + tracker1.start(task.task_id, account_id="acme", user_id="alice") + + tracker2 = TaskTracker(store=PersistentTaskStore(agfs)) + loaded = tracker2.get(task.task_id, account_id="acme", user_id="alice") + + assert loaded is not None + assert loaded.status == TaskStatus.RUNNING + + +def test_persistent_store_ignores_existing_task_dirs(): + agfs = _FakeAgfsExistingDir() + tracker = TaskTracker(store=PersistentTaskStore(agfs)) + + first = tracker.create("session_commit", resource_id="sess-1", **_owner_kwargs()) + second = tracker.create("session_commit", resource_id="sess-2", **_owner_kwargs()) + + assert first.task_id != second.task_id + assert agfs.files[f"/local/acme/tasks/alice/{first.task_id}.json"] + assert agfs.files[f"/local/acme/tasks/alice/{second.task_id}.json"] + + def test_create_requires_owner(tracker: TaskTracker): with pytest.raises(TypeError): tracker.create("session_commit", resource_id="sess-123") @@ -343,8 +465,8 @@ def test_create_rejects_blank_owner_values(tracker: TaskTracker): tracker.create( "session_commit", resource_id="sess-123", - owner_account_id="", - owner_user_id="alice", + account_id="", + user_id="alice", ) @@ -361,3 +483,17 @@ async def test_session_service_get_commit_task_is_owner_scoped(): assert owner_result["task_id"] == task.task_id assert owner_result["resource_id"] == "sess-123" assert other_result is None + + +@pytest.mark.asyncio +async def test_session_service_get_commit_task_also_filters_account(): + tracker = get_task_tracker() + task = 
tracker.create("session_commit", resource_id="sess-123", **_owner_kwargs()) + service = SessionService() + + other_account_result = await service.get_commit_task( + task.task_id, + _make_ctx(account_id="other-acme", user_id="alice"), + ) + + assert other_account_result is None