From eb12a5b22b840f4a316c8656b690e51782abbc9a Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Sat, 9 May 2026 16:43:44 +0800 Subject: [PATCH 1/2] feat(tracker): persist task tracker state across instances fix: default task tracker to memory backend --- openviking/service/core.py | 2 + openviking/service/reindex_executor.py | 6 +- openviking/service/resource_service.py | 48 +++-- openviking/service/session_service.py | 1 + openviking/service/task_store.py | 164 ++++++++++++++++++ openviking/service/task_tracker.py | 92 ++++++++-- openviking/session/session.py | 6 +- openviking/storage/viking_fs.py | 2 +- openviking_cli/utils/config/storage_config.py | 25 ++- tests/misc/test_vikingfs_uri_guard.py | 12 ++ tests/test_task_backend_config.py | 45 +++++ tests/test_task_tracker.py | 136 +++++++++++++++ 12 files changed, 509 insertions(+), 30 deletions(-) create mode 100644 openviking/service/task_store.py create mode 100644 tests/test_task_backend_config.py diff --git a/openviking/service/core.py b/openviking/service/core.py index 89c2634cb..0caab80e6 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -21,6 +21,7 @@ from openviking.service.resource_service import ResourceService from openviking.service.search_service import SearchService from openviking.service.session_service import SessionService +from openviking.service.task_tracker import set_task_tracker from openviking.session import SessionCompressor, create_session_compressor from openviking.storage import VikingDBManager from openviking.storage.collection_schemas import init_context_collection @@ -149,6 +150,7 @@ def _init_storage( lock_expire=tx_cfg.lock_expire, redo_recovery_enabled=tx_cfg.redo_recovery_enabled, ) + set_task_tracker(config.build_task_tracker(self._agfs_client)) @property def _agfs(self) -> Any: diff --git a/openviking/service/reindex_executor.py b/openviking/service/reindex_executor.py index c4b84dad7..26b534a64 100644 --- a/openviking/service/reindex_executor.py +++ b/openviking/service/reindex_executor.py @@ -376,7 +376,7 @@ async def _run_tracked( ctx: RequestContext, ) -> None: tracker = get_task_tracker() - tracker.start(task_id) + tracker.start(task_id, owner_account_id=ctx.account_id) try: result = await self._run( uri=uri, @@ -384,9 +384,9 @@ async def _run_tracked( mode=mode, ctx=ctx, ) - tracker.complete(task_id, result) + tracker.complete(task_id, result, owner_account_id=ctx.account_id) except Exception as exc: - tracker.fail(task_id, str(exc)) + tracker.fail(task_id, str(exc), owner_account_id=ctx.account_id) async def _reindex_resource( self, diff --git a/openviking/service/resource_service.py b/openviking/service/resource_service.py index 72aa043a7..8c7dc4b62 100644 --- a/openviking/service/resource_service.py +++ b/openviking/service/resource_service.py @@ -301,10 +301,20 @@ async def add_resource( result["task_id"] = task.task_id if telemetry_id: monitor_started = True - asyncio.create_task(self._monitor_queue_processing(task.task_id, telemetry_id)) + asyncio.create_task( + self._monitor_queue_processing( + task.task_id, + telemetry_id, + ctx.account_id, + ) + ) else: - task_tracker.start(task.task_id) - task_tracker.complete(task.task_id, {"root_uri": root_uri}) + task_tracker.start(task.task_id, owner_account_id=ctx.account_id) + task_tracker.complete( + task.task_id, + {"root_uri": root_uri}, + owner_account_id=ctx.account_id, + ) return result except Exception as exc: telemetry.set_error( @@ -322,22 +332,32 @@ async def add_resource( get_request_wait_tracker().cleanup(telemetry_id) unregister_wait_telemetry(telemetry_id) - async def _monitor_queue_processing(self, task_id: str, telemetry_id: str) -> None: + async def _monitor_queue_processing( + self, task_id: str, telemetry_id: str, owner_account_id: str + ) -> None: from openviking.service.task_tracker import get_task_tracker task_tracker = get_task_tracker() request_wait_tracker = get_request_wait_tracker() - task_tracker.start(task_id) + task_tracker.start(task_id, owner_account_id=owner_account_id) try: await request_wait_tracker.wait_for_request(telemetry_id) status = request_wait_tracker.build_queue_status(telemetry_id) errors = sum(int(group.get("error_count", 0) or 0) for group in status.values()) if errors: - task_tracker.fail(task_id, f"queue processing failed: {status}") + task_tracker.fail( + task_id, + f"queue processing failed: {status}", + owner_account_id=owner_account_id, + ) else: - task_tracker.complete(task_id, {"queue_status": status}) + task_tracker.complete( + task_id, + {"queue_status": status}, + owner_account_id=owner_account_id, + ) except Exception as exc: - task_tracker.fail(task_id, str(exc)) + task_tracker.fail(task_id, str(exc), owner_account_id=owner_account_id) finally: request_wait_tracker.cleanup(telemetry_id) unregister_wait_telemetry(telemetry_id) @@ -524,10 +544,16 @@ async def add_skill( result["task_id"] = task.task_id if telemetry_id: monitor_started = True - asyncio.create_task(self._monitor_queue_processing(task.task_id, telemetry_id)) + asyncio.create_task( + self._monitor_queue_processing( + task.task_id, + telemetry_id, + ctx.account_id, + ) + ) else: - task_tracker.start(task.task_id) - task_tracker.complete(task.task_id, {}) + task_tracker.start(task.task_id, owner_account_id=ctx.account_id) + task_tracker.complete(task.task_id, {}, owner_account_id=ctx.account_id) return result finally: diff --git a/openviking/service/session_service.py b/openviking/service/session_service.py index 7fb85d116..c0baee4ab 100644 --- a/openviking/service/session_service.py +++ b/openviking/service/session_service.py @@ -253,6 +253,7 @@ async def get_commit_task(self, task_id: str, ctx: RequestContext) -> Optional[D task = get_task_tracker().get( task_id, owner_account_id=ctx.account_id, + owner_user_id=ctx.user.user_id, ) return task.to_dict() if task else None diff --git a/openviking/service/task_store.py b/openviking/service/task_store.py new file mode 100644 index 000000000..aa7579beb --- /dev/null +++ b/openviking/service/task_store.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 +"""Internal storage backends for TaskTracker.""" + +from __future__ import annotations + +import json +from copy import deepcopy +from typing import Any, Dict, List, Optional, Protocol + +from openviking.pyagfs.exceptions import AGFSAlreadyExistsError + + +class TaskStore(Protocol): + def create(self, task: Any) -> None: ... + + def update(self, task: Any) -> None: ... + + def get( + self, task_id: str, *, owner_account_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: ... + + def list(self, owner_account_id: str) -> List[Dict[str, Any]]: ... + + def delete(self, task_id: str, *, owner_account_id: str) -> None: ... + + +class InMemoryTaskStore: + """Simple in-process task store.""" + + def __init__(self) -> None: + self._tasks: Dict[str, Dict[str, Any]] = {} + + def create(self, task: Any) -> None: + self._tasks[task.task_id] = _task_to_payload(task) + + def update(self, task: Any) -> None: + self._tasks[task.task_id] = _task_to_payload(task) + + def get( + self, task_id: str, *, owner_account_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + payload = self._tasks.get(task_id) + if payload is None: + return None + if owner_account_id is not None and payload.get("owner_account_id") != owner_account_id: + return None + return deepcopy(payload) + + def list(self, owner_account_id: str) -> List[Dict[str, Any]]: + return [ + deepcopy(payload) + for payload in self._tasks.values() + if payload.get("owner_account_id") == owner_account_id + ] + + def delete(self, task_id: str, *, owner_account_id: str) -> None: + payload = self._tasks.get(task_id) + if payload and payload.get("owner_account_id") == owner_account_id: + del self._tasks[task_id] + + +class PersistentTaskStore: + """Persist task records into AGFS under account-scoped task directories.""" + + ROOT_PREFIX = "/local" + RESERVED_DIRNAME = "tasks" + + def __init__(self, agfs: Any) -> None: + self._agfs = agfs + + def create(self, task: Any) -> None: + self._write_task(task) + + def update(self, task: Any) -> None: + self._write_task(task) + + def get( + self, task_id: str, *, owner_account_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + if not owner_account_id: + return None + path = self._task_path(owner_account_id, task_id) + try: + raw = self._agfs.read(path) + except Exception: + return None + return json.loads(_decode_bytes(raw)) + + def list(self, owner_account_id: str) -> List[Dict[str, Any]]: + directory = self._task_dir(owner_account_id) + try: + items = self._agfs.ls(directory) + except Exception: + return [] + tasks: List[Dict[str, Any]] = [] + for item in items: + path = item.get("path") or f"{directory}/{item.get('name', '')}" + if not path.endswith(".json"): + continue + try: + raw = self._agfs.read(path) + tasks.append(json.loads(_decode_bytes(raw))) + except Exception: + continue + return tasks + + def delete(self, task_id: str, *, owner_account_id: str) -> None: + self._agfs.rm(self._task_path(owner_account_id, task_id), force=True) + + def _write_task(self, task: Any) -> None: + account_id = getattr(task, "owner_account_id", None) + if not account_id: + raise ValueError("PersistentTaskStore requires owner_account_id") + self._ensure_task_dir(account_id) + self._agfs.write( + self._task_path(account_id, task.task_id), + json.dumps(_task_to_payload(task), ensure_ascii=False).encode("utf-8"), + ) + + def _ensure_task_dir(self, account_id: str) -> None: + self._mkdir_if_missing(self._account_dir(account_id)) + self._mkdir_if_missing(self._task_dir(account_id)) + + def _mkdir_if_missing(self, path: str) -> None: + try: + self._agfs.mkdir(path) + except AGFSAlreadyExistsError: + return + except Exception as exc: + if "already exists" in str(exc).lower(): + return + raise + + def _account_dir(self, account_id: str) -> str: + return f"{self.ROOT_PREFIX}/{account_id}" + + def _task_dir(self, account_id: str) -> str: + return f"{self._account_dir(account_id)}/{self.RESERVED_DIRNAME}" + + def _task_path(self, account_id: str, task_id: str) -> str: + return f"{self._task_dir(account_id)}/{task_id}.json" + + +def _task_to_payload(task: Any) -> Dict[str, Any]: + status = getattr(task, "status", None) + return { + "task_id": task.task_id, + "task_type": task.task_type, + "status": status.value if hasattr(status, "value") else status, + "created_at": task.created_at, + "updated_at": task.updated_at, + "resource_id": task.resource_id, + "owner_account_id": task.owner_account_id, + "owner_user_id": task.owner_user_id, + "result": deepcopy(task.result), + "error": task.error, + } + + +def _decode_bytes(raw: Any) -> str: + if isinstance(raw, bytes): + return raw.decode("utf-8") + return str(raw) diff --git a/openviking/service/task_tracker.py b/openviking/service/task_tracker.py index 27bdad1e7..8c847b1d7 100644 --- a/openviking/service/task_tracker.py +++ b/openviking/service/task_tracker.py @@ -3,14 +3,13 @@ """ Async Task Tracker for OpenViking. -Provides a lightweight, in-memory registry for tracking background operations +Provides a lightweight registry for tracking background operations (e.g. session commit with wait=false). Callers receive a task_id that can be polled via the /tasks API to check completion status, results, or errors. Design decisions: - - v1 is pure in-memory (no persistence). Tasks are lost on restart. - Thread-safe (QueueManager workers run in separate threads). - - TTL-based cleanup prevents unbounded memory growth. + - TTL-based cleanup still applies to in-memory cache. - Error messages are sanitized to avoid leaking sensitive data. """ @@ -25,6 +24,7 @@ from typing import Any, Dict, List, Optional from uuid import uuid4 +from openviking.service.task_store import InMemoryTaskStore, TaskStore from openviking_cli.utils.logger import get_logger logger = get_logger(__name__) @@ -81,6 +81,13 @@ def get_task_tracker() -> "TaskTracker": return _instance +def set_task_tracker(tracker: "TaskTracker") -> None: + """Replace the global TaskTracker singleton.""" + global _instance + with _init_lock: + _instance = tracker + + def reset_task_tracker() -> None: """Reset singleton (for testing).""" global _instance @@ -109,7 +116,7 @@ def _sanitize_error(error: str) -> str: class TaskTracker: - """In-memory async task tracker with TTL-based cleanup. + """Async task tracker with pluggable storage and in-memory compatibility cache. Thread-safe: all mutations go through ``_lock``. """ @@ -119,11 +126,16 @@ class TaskTracker: TTL_FAILED = 604_800 # 7 days CLEANUP_INTERVAL = 300 # 5 minutes - def __init__(self) -> None: + def __init__(self, store: Optional[TaskStore] = None) -> None: + self._store = store or InMemoryTaskStore() self._tasks: Dict[str, TaskRecord] = {} self._lock = threading.Lock() self._cleanup_task: Optional[asyncio.Task] = None - logger.info("[TaskTracker] Initialized (in-memory, max_tasks=%d)", self.MAX_TASKS) + logger.info( + "[TaskTracker] Initialized (store=%s, max_tasks=%d)", + self._store.__class__.__name__, + self.MAX_TASKS, + ) # ── Lifecycle ── @@ -165,6 +177,9 @@ def _evict_expired(self) -> None: elif t.status == TaskStatus.FAILED and (now - t.updated_at) > self.TTL_FAILED: to_delete.append(tid) for tid in to_delete: + task = self._tasks[tid] + if isinstance(self._store, InMemoryTaskStore) and task.owner_account_id: + self._store.delete(tid, owner_account_id=task.owner_account_id) del self._tasks[tid] # FIFO eviction if still over limit @@ -172,6 +187,9 @@ def _evict_expired(self) -> None: sorted_tasks = sorted(self._tasks.items(), key=lambda x: x[1].created_at) excess = len(self._tasks) - self.MAX_TASKS for tid, _ in sorted_tasks[:excess]: + task = self._tasks[tid] + if isinstance(self._store, InMemoryTaskStore) and task.owner_account_id: + self._store.delete(tid, owner_account_id=task.owner_account_id) del self._tasks[tid] if to_delete: @@ -217,6 +235,7 @@ def create( ) with self._lock: self._tasks[task.task_id] = task + self._store.create(task) logger.debug( "[TaskTracker] Created task %s type=%s resource=%s", task.task_id, @@ -259,6 +278,7 @@ def create_if_no_running( owner_user_id=owner_user_id, ) self._tasks[task.task_id] = task + self._store.create(task) logger.debug( "[TaskTracker] Created task %s type=%s resource=%s", task.task_id, @@ -267,32 +287,43 @@ def create_if_no_running( ) return self._copy(task) - def start(self, task_id: str) -> None: + def start(self, task_id: str, owner_account_id: Optional[str] = None) -> None: """Transition task to RUNNING.""" with self._lock: - task = self._tasks.get(task_id) + task = self._load_for_update(task_id, owner_account_id) if task: task.status = TaskStatus.RUNNING task.updated_at = time.time() + self._tasks[task.task_id] = task + self._store.update(task) - def complete(self, task_id: str, result: Optional[Dict[str, Any]] = None) -> None: + def complete( + self, + task_id: str, + result: Optional[Dict[str, Any]] = None, + owner_account_id: Optional[str] = None, + ) -> None: """Transition task to COMPLETED with optional result.""" with self._lock: - task = self._tasks.get(task_id) + task = self._load_for_update(task_id, owner_account_id) if task: task.status = TaskStatus.COMPLETED task.result = result task.updated_at = time.time() + self._tasks[task.task_id] = task + self._store.update(task) logger.info("[TaskTracker] Task %s completed", task_id) - def fail(self, task_id: str, error: str) -> None: + def fail(self, task_id: str, error: str, owner_account_id: Optional[str] = None) -> None: """Transition task to FAILED with sanitized error.""" with self._lock: - task = self._tasks.get(task_id) + task = self._load_for_update(task_id, owner_account_id) if task: task.status = TaskStatus.FAILED task.error = _sanitize_error(error) task.updated_at = time.time() + self._tasks[task.task_id] = task + self._store.update(task) logger.warning("[TaskTracker] Task %s failed: %s", task_id, _sanitize_error(error)) def get( @@ -304,6 +335,10 @@ def get( """Look up a single task. Returns a snapshot copy (None if not found).""" with self._lock: task = self._tasks.get(task_id) + if task is None and owner_account_id is not None: + task = self._load_from_store(task_id, owner_account_id) + if task is not None: + self._tasks[task.task_id] = task if task is None or not self._matches_owner(task, owner_account_id, owner_user_id): return None return self._copy(task) @@ -319,6 +354,9 @@ def list_tasks( ) -> List[TaskRecord]: """List tasks with optional filters. Most-recent first. Returns snapshot copies.""" with self._lock: + if owner_account_id is not None: + for task in self._load_all_from_store(owner_account_id): + self._tasks[task.task_id] = task tasks = [ self._copy(t) for t in self._tasks.values() @@ -342,6 +380,9 @@ def has_running( ) -> bool: """Check if there is already a running task for the given type+resource.""" with self._lock: + if owner_account_id is not None: + for task in self._load_all_from_store(owner_account_id): + self._tasks[task.task_id] = task return any( t.task_type == task_type and t.resource_id == resource_id @@ -350,6 +391,33 @@ def has_running( for t in self._tasks.values() ) + def _load_for_update( + self, task_id: str, owner_account_id: Optional[str] + ) -> Optional[TaskRecord]: + task = self._tasks.get(task_id) + if task is not None: + return task + if owner_account_id is None: + return None + return self._load_from_store(task_id, owner_account_id) + + @staticmethod + def _record_from_payload(payload: Dict[str, Any]) -> TaskRecord: + data = dict(payload) + data["status"] = TaskStatus(data["status"]) + return TaskRecord(**data) + + def _load_from_store(self, task_id: str, owner_account_id: str) -> Optional[TaskRecord]: + payload = self._store.get(task_id, owner_account_id=owner_account_id) + if payload is None: + return None + return self._record_from_payload(payload) + + def _load_all_from_store(self, owner_account_id: str) -> List[TaskRecord]: + return [ + self._record_from_payload(payload) for payload in self._store.list(owner_account_id) + ] + @staticmethod def _copy(task: TaskRecord) -> TaskRecord: """Return a defensive copy of a TaskRecord.""" diff --git a/openviking/session/session.py b/openviking/session/session.py index e5bacb702..15b2e0ce1 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -773,10 +773,11 @@ async def _run_memory_extraction( task_id, f"Previous archive archive_{archive_index - 1:03d} failed; " "cannot continue session commit", + owner_account_id=self.ctx.account_id, ) return - tracker.start(task_id) + tracker.start(task_id, owner_account_id=self.ctx.account_id) request_wait_tracker.register_request(telemetry.telemetry_id) register_telemetry(telemetry) try: @@ -979,6 +980,7 @@ async def _noop_agent(): }, }, }, + owner_account_id=self.ctx.account_id, ) logger.info(f"Session {self.session_id} memory extraction completed") except Exception as e: @@ -989,7 +991,7 @@ async def _noop_agent(): stage="memory_extraction", error=str(e), ) - tracker.fail(task_id, str(e)) + tracker.fail(task_id, str(e), owner_account_id=self.ctx.account_id) logger.exception(f"Memory extraction failed for session {self.session_id}") async def _write_done_file( diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index c5837e762..9bf623e89 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -1629,7 +1629,7 @@ def _uri_to_path(self, uri: str, ctx: Optional[RequestContext] = None) -> str: safe_parts = [self._shorten_component(p, self._MAX_FILENAME_BYTES) for p in parts] return f"/local/{account_id}/{'/'.join(safe_parts)}" - _INTERNAL_NAMES = {"_system", ".path.ovlock"} + _INTERNAL_NAMES = {"_system", "tasks", ".path.ovlock"} _ROOT_PATH = "/local" def _ls_entries(self, path: str) -> List[Dict[str, Any]]: diff --git a/openviking_cli/utils/config/storage_config.py b/openviking_cli/utils/config/storage_config.py index 59aa2430c..1ea24a61a 100644 --- a/openviking_cli/utils/config/storage_config.py +++ b/openviking_cli/utils/config/storage_config.py @@ -1,7 +1,7 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: AGPL-3.0 from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Literal from pydantic import BaseModel, Field, model_validator @@ -14,6 +14,15 @@ logger = get_logger(__name__) +class TaskTrackerConfig(BaseModel): + """Configuration for async task tracking backend.""" + + backend: Literal["memory", "persistent"] = Field( + default="memory", + description="Task tracker backend. 'persistent' enables cross-instance task lookup.", + ) + + class StorageConfig(BaseModel): """Configuration for storage backend. @@ -36,6 +45,11 @@ class StorageConfig(BaseModel): description="VectorDB backend configuration", ) + task_tracker: TaskTrackerConfig = Field( + default_factory=TaskTrackerConfig, + description="Task tracker backend configuration", + ) + params: Dict[str, Any] = Field( default_factory=dict, description="Additional storage-specific parameters" ) @@ -75,3 +89,12 @@ def get_upload_temp_dir(self) -> Path: upload_temp_dir = workspace_path / "temp" / "upload" upload_temp_dir.mkdir(parents=True, exist_ok=True) return upload_temp_dir + + def build_task_tracker(self, agfs: Any): + """Build a TaskTracker from storage config.""" + from openviking.service.task_store import PersistentTaskStore + from openviking.service.task_tracker import TaskTracker + + if self.task_tracker.backend == "memory": + return TaskTracker() + return TaskTracker(store=PersistentTaskStore(agfs)) diff --git a/tests/misc/test_vikingfs_uri_guard.py b/tests/misc/test_vikingfs_uri_guard.py index febdba1c2..36382a902 100644 --- a/tests/misc/test_vikingfs_uri_guard.py +++ b/tests/misc/test_vikingfs_uri_guard.py @@ -194,3 +194,15 @@ async def test_grep_with_agfs_maps_dot_match_to_query_root_uri(self) -> None: assert result["matches"][0]["uri"] == "viking://resources/test-root" assert result["matches"][0]["line"] == 1 assert result["matches"][0]["content"] == "act" + + def test_ls_entries_hides_reserved_tasks_dir_under_account_root(self) -> None: + fs = _make_viking_fs() + fs.agfs.ls.return_value = [ + {"name": "resources", "isDir": True}, + {"name": "tasks", "isDir": True}, + {"name": "_system", "isDir": True}, + ] + + entries = fs._ls_entries("/local/default") + + assert [entry["name"] for entry in entries] == ["resources"] diff --git a/tests/test_task_backend_config.py b/tests/test_task_backend_config.py new file mode 100644 index 000000000..10d41b0fe --- /dev/null +++ b/tests/test_task_backend_config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 + +from openviking.service.task_store import InMemoryTaskStore, PersistentTaskStore +from openviking.service.task_tracker import TaskTracker +from openviking_cli.utils.config.storage_config import StorageConfig + + +class _FakeAgfs: + def mkdir(self, path: str, mode: str = "755"): + return {"message": "created", "mode": mode} + + def write(self, path: str, data): + return "OK" + + def read(self, path: str, offset: int = 0, size: int = -1, stream: bool = False): + raise FileNotFoundError(path) + + def ls(self, path: str = "/"): + return [] + + def rm(self, path: str, recursive: bool = False, force: bool = True): + return {"message": "deleted"} + + +def test_storage_config_defaults_task_backend_to_memory(): + config = StorageConfig() + assert config.task_tracker.backend == "memory" + + +def test_storage_config_accepts_memory_task_backend(): + config = StorageConfig(task_tracker={"backend": "memory"}) + assert config.task_tracker.backend == "memory" + + +def test_storage_config_builds_memory_task_tracker(): + tracker = StorageConfig(task_tracker={"backend": "memory"}).build_task_tracker(_FakeAgfs()) + assert isinstance(tracker, TaskTracker) + assert isinstance(tracker._store, InMemoryTaskStore) + + +def test_storage_config_builds_persistent_task_tracker(): + tracker = StorageConfig(task_tracker={"backend": "persistent"}).build_task_tracker(_FakeAgfs()) + assert isinstance(tracker, TaskTracker) + assert isinstance(tracker._store, PersistentTaskStore) diff --git a/tests/test_task_tracker.py b/tests/test_task_tracker.py index e3eaf597f..76b0c9a00 100644 --- a/tests/test_task_tracker.py +++ b/tests/test_task_tracker.py @@ -3,12 +3,15 @@ """Unit tests for TaskTracker.""" +import json import time import pytest +from openviking.pyagfs.exceptions import AGFSAlreadyExistsError from openviking.server.identity import RequestContext, Role from openviking.service.session_service import SessionService +from openviking.service.task_store import InMemoryTaskStore, PersistentTaskStore from openviking.service.task_tracker import ( TaskStatus, TaskTracker, @@ -46,6 +49,60 @@ def _make_ctx(account_id: str = "acme", user_id: str = "alice") -> RequestContex ) +class _FakeAgfs: + def __init__(self): + self.files = {} + self.dirs = {"/", "/local"} + + def mkdir(self, path: str, mode: str = "755"): + self.dirs.add(path.rstrip("/") or "/") + return {"message": "created", "mode": mode} + + def write(self, path: str, data): + if isinstance(data, str): + data = data.encode("utf-8") + self.files[path] = data + parent = path.rsplit("/", 1)[0] or "/" + self.dirs.add(parent) + return "OK" + + def read(self, path: str, offset: int = 0, size: int = -1, stream: bool = False): + if path not in self.files: + raise FileNotFoundError(path) + data = self.files[path] + if size >= 0: + return data[offset : offset + size] + return data[offset:] + + def ls(self, path: str = "/"): + prefix = path.rstrip("/") or "/" + if prefix not in self.dirs: + return [] + children = {} + for directory in self.dirs: + if directory in {prefix, "/"}: + continue + if directory.startswith(prefix + "/"): + name = directory[len(prefix) + 1 :].split("/", 1)[0] + if name: + children[name] = {"name": name, "path": f"{prefix}/{name}", "is_dir": True} + for file_path in self.files: + if file_path.startswith(prefix + "/"): + name = file_path[len(prefix) + 1 :].split("/", 1)[0] + if name and "/" not in file_path[len(prefix) + 1 :]: + children[name] = {"name": name, "path": f"{prefix}/{name}", "is_dir": False} + return list(children.values()) + + +class _FakeAgfsExistingDir(_FakeAgfs): + def mkdir(self, path: str, mode: str = "755"): + normalized = path.rstrip("/") or "/" + if normalized in self.dirs: + raise AGFSAlreadyExistsError(f"already exists: {path}") + self.dirs.add(normalized) + return {"message": "created", "mode": mode} + + # ── Basic CRUD ── @@ -328,6 +385,71 @@ def test_singleton_reset(): assert t1 is not t2 +def test_persistent_store_cross_tracker_visibility(): + agfs = _FakeAgfs() + store = PersistentTaskStore(agfs) + tracker1 = TaskTracker(store=store) + tracker2 = TaskTracker(store=store) + + task = tracker1.create("session_commit", resource_id="sess-123", **_owner_kwargs()) + tracker1.start(task.task_id, owner_account_id="acme") + tracker1.complete(task.task_id, {"ok": True}, owner_account_id="acme") + + loaded = tracker2.get(task.task_id, owner_account_id="acme", owner_user_id="alice") + + assert loaded is not None + assert loaded.status == TaskStatus.COMPLETED + assert loaded.result == {"ok": True} + + +def test_persistent_store_writes_task_record_json(): + agfs = _FakeAgfs() + store = PersistentTaskStore(agfs) + tracker = TaskTracker(store=store) + + task = tracker.create("add_resource", resource_id="viking://resources/demo", **_owner_kwargs()) + + raw = agfs.files[f"/local/acme/tasks/{task.task_id}.json"] + payload = json.loads(raw.decode("utf-8")) + + assert payload["task_id"] == task.task_id + assert payload["task_type"] == "add_resource" + assert payload["owner_account_id"] == "acme" + assert payload["owner_user_id"] == "alice" + assert "schema_version" not in payload + + +def test_inmemory_store_keeps_tasktracker_tasks_dict(): + tracker = TaskTracker(store=InMemoryTaskStore()) + task = tracker.create("session_commit", **_owner_kwargs()) + assert task.task_id in tracker._tasks + + +def test_persistent_store_survives_tracker_reset(): + agfs = _FakeAgfs() + tracker1 = TaskTracker(store=PersistentTaskStore(agfs)) + task = tracker1.create("session_commit", resource_id="sess-123", **_owner_kwargs()) + tracker1.start(task.task_id, owner_account_id="acme") + + tracker2 = TaskTracker(store=PersistentTaskStore(agfs)) + loaded = tracker2.get(task.task_id, owner_account_id="acme", owner_user_id="alice") + + assert loaded is not None + assert loaded.status == TaskStatus.RUNNING + + +def test_persistent_store_ignores_existing_task_dirs(): + agfs = _FakeAgfsExistingDir() + tracker = TaskTracker(store=PersistentTaskStore(agfs)) + + first = tracker.create("session_commit", resource_id="sess-1", **_owner_kwargs()) + second = tracker.create("session_commit", resource_id="sess-2", **_owner_kwargs()) + + assert first.task_id != second.task_id + assert agfs.files[f"/local/acme/tasks/{first.task_id}.json"] + assert agfs.files[f"/local/acme/tasks/{second.task_id}.json"] + + def test_create_requires_owner(tracker: TaskTracker): with pytest.raises(TypeError): tracker.create("session_commit", resource_id="sess-123") @@ -361,3 +483,17 @@ async def test_session_service_get_commit_task_is_owner_scoped(): assert owner_result["task_id"] == task.task_id assert owner_result["resource_id"] == "sess-123" assert other_result is None + + +@pytest.mark.asyncio +async def test_session_service_get_commit_task_also_filters_account(): + tracker = get_task_tracker() + task = tracker.create("session_commit", resource_id="sess-123", **_owner_kwargs()) + service = SessionService() + + other_account_result = await service.get_commit_task( + task.task_id, + _make_ctx(account_id="other-acme", user_id="alice"), + ) + + assert other_account_result is None From 65c9b605d5794e3e75a69a76db6f9ab50f06ee29 Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Sat, 9 May 2026 20:05:21 +0800 Subject: [PATCH 2/2] refactor: scope persistent task paths by user --- openviking/server/routers/tasks.py | 8 +- openviking/service/reindex_executor.py | 14 +-- openviking/service/resource_service.py | 43 ++++++--- openviking/service/session_service.py | 4 +- openviking/service/task_store.py | 84 ++++++++++------ openviking/service/task_tracker.py | 129 ++++++++++++++----------- openviking/session/session.py | 16 +-- tests/server/test_auth.py | 8 +- tests/test_session_task_tracking.py | 88 +++++++++++++---- tests/test_task_tracker.py | 58 +++++------ 10 files changed, 284 insertions(+), 168 deletions(-) diff --git a/openviking/server/routers/tasks.py b/openviking/server/routers/tasks.py index a7247b604..f712bd8a9 100644 --- a/openviking/server/routers/tasks.py +++ b/openviking/server/routers/tasks.py @@ -29,8 +29,8 @@ async def get_task( tracker = get_task_tracker() task = tracker.get( task_id, - owner_account_id=_ctx.account_id, - owner_user_id=_ctx.user.user_id, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, ) if not task: raise OpenVikingError( @@ -58,7 +58,7 @@ async def list_tasks( status=status, resource_id=resource_id, limit=limit, - owner_account_id=_ctx.account_id, - owner_user_id=_ctx.user.user_id, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, ) return Response(status="ok", result=[t.to_dict() for t in tasks]) diff --git a/openviking/service/reindex_executor.py b/openviking/service/reindex_executor.py index 26b534a64..334414f7b 100644 --- a/openviking/service/reindex_executor.py +++ b/openviking/service/reindex_executor.py @@ -89,8 +89,8 @@ async def execute( if tracker.has_running( REINDEX_TASK_TYPE, uri, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ): raise OpenVikingError( f"URI {uri} already has a reindex in progress", @@ -107,8 +107,8 @@ async def execute( task = tracker.create_if_no_running( REINDEX_TASK_TYPE, uri, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) if task is None: raise OpenVikingError( @@ -376,7 +376,7 @@ async def _run_tracked( ctx: RequestContext, ) -> None: tracker = get_task_tracker() - tracker.start(task_id, owner_account_id=ctx.account_id) + tracker.start(task_id, account_id=ctx.account_id, user_id=ctx.user.user_id) try: result = await self._run( uri=uri, @@ -384,9 +384,9 @@ async def _run_tracked( mode=mode, ctx=ctx, ) - tracker.complete(task_id, result, owner_account_id=ctx.account_id) + tracker.complete(task_id, result, account_id=ctx.account_id, user_id=ctx.user.user_id) except Exception as exc: - tracker.fail(task_id, str(exc), owner_account_id=ctx.account_id) + tracker.fail(task_id, str(exc), account_id=ctx.account_id, user_id=ctx.user.user_id) async def _reindex_resource( self, diff --git a/openviking/service/resource_service.py b/openviking/service/resource_service.py index 8c7dc4b62..a38afb570 100644 --- a/openviking/service/resource_service.py +++ b/openviking/service/resource_service.py @@ -295,8 +295,8 @@ async def add_resource( task = task_tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) result["task_id"] = task.task_id if telemetry_id: @@ -309,11 +309,14 @@ async def add_resource( ) ) else: - task_tracker.start(task.task_id, owner_account_id=ctx.account_id) + task_tracker.start( + task.task_id, account_id=ctx.account_id, user_id=ctx.user.user_id + ) task_tracker.complete( task.task_id, {"root_uri": root_uri}, - owner_account_id=ctx.account_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) return result except Exception as exc: @@ -333,13 +336,17 @@ async def add_resource( unregister_wait_telemetry(telemetry_id) async def _monitor_queue_processing( - self, task_id: str, telemetry_id: str, owner_account_id: str + self, + task_id: str, + telemetry_id: str, + account_id: str, + user_id: str, ) -> None: from openviking.service.task_tracker import get_task_tracker task_tracker = get_task_tracker() request_wait_tracker = get_request_wait_tracker() - task_tracker.start(task_id, owner_account_id=owner_account_id) + task_tracker.start(task_id, account_id=account_id, user_id=user_id) try: await request_wait_tracker.wait_for_request(telemetry_id) status = request_wait_tracker.build_queue_status(telemetry_id) @@ -348,16 +355,18 @@ async def _monitor_queue_processing( task_tracker.fail( task_id, f"queue processing failed: {status}", - owner_account_id=owner_account_id, + account_id=account_id, + user_id=user_id, ) else: task_tracker.complete( task_id, {"queue_status": status}, - owner_account_id=owner_account_id, + account_id=account_id, + user_id=user_id, ) except Exception as exc: - task_tracker.fail(task_id, str(exc), owner_account_id=owner_account_id) + task_tracker.fail(task_id, str(exc), account_id=account_id, user_id=user_id) finally: request_wait_tracker.cleanup(telemetry_id) unregister_wait_telemetry(telemetry_id) @@ -538,8 +547,8 @@ async def add_skill( task_tracker = get_task_tracker() task = task_tracker.create( "add_skill", - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) result["task_id"] = task.task_id if telemetry_id: @@ -549,11 +558,19 @@ async def add_skill( task.task_id, telemetry_id, ctx.account_id, + ctx.user.user_id, ) ) else: - task_tracker.start(task.task_id, owner_account_id=ctx.account_id) - task_tracker.complete(task.task_id, {}, owner_account_id=ctx.account_id) + task_tracker.start( + task.task_id, account_id=ctx.account_id, user_id=ctx.user.user_id + ) + task_tracker.complete( + task.task_id, + {}, + account_id=ctx.account_id, + user_id=ctx.user.user_id, + ) return result finally: diff --git a/openviking/service/session_service.py b/openviking/service/session_service.py index c0baee4ab..e0b9ece58 100644 --- a/openviking/service/session_service.py +++ b/openviking/service/session_service.py @@ -252,8 +252,8 @@ async def get_commit_task(self, task_id: str, ctx: RequestContext) -> Optional[D """Query background commit task status by task_id for the calling owner.""" task = get_task_tracker().get( task_id, - owner_account_id=ctx.account_id, - owner_user_id=ctx.user.user_id, + account_id=ctx.account_id, + user_id=ctx.user.user_id, ) return task.to_dict() if task else None diff --git a/openviking/service/task_store.py b/openviking/service/task_store.py index aa7579beb..770bbd026 100644 --- a/openviking/service/task_store.py +++ b/openviking/service/task_store.py @@ -17,12 +17,16 @@ def create(self, task: Any) -> None: ... def update(self, task: Any) -> None: ... def get( - self, task_id: str, *, owner_account_id: Optional[str] = None + self, + task_id: str, + *, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> Optional[Dict[str, Any]]: ... - def list(self, owner_account_id: str) -> List[Dict[str, Any]]: ... + def list(self, account_id: str, *, user_id: Optional[str] = None) -> List[Dict[str, Any]]: ... - def delete(self, task_id: str, *, owner_account_id: str) -> None: ... + def delete(self, task_id: str, *, account_id: str, user_id: Optional[str] = None) -> None: ... class InMemoryTaskStore: @@ -38,25 +42,36 @@ def update(self, task: Any) -> None: self._tasks[task.task_id] = _task_to_payload(task) def get( - self, task_id: str, *, owner_account_id: Optional[str] = None + self, + task_id: str, + *, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> Optional[Dict[str, Any]]: payload = self._tasks.get(task_id) if payload is None: return None - if owner_account_id is not None and payload.get("owner_account_id") != owner_account_id: + if account_id is not None and payload.get("account_id") != account_id: + return None + if user_id is not None and payload.get("user_id") != user_id: return None return deepcopy(payload) - def list(self, owner_account_id: str) -> List[Dict[str, Any]]: + def list(self, account_id: str, *, user_id: Optional[str] = None) -> List[Dict[str, Any]]: return [ deepcopy(payload) for payload in self._tasks.values() - if payload.get("owner_account_id") == owner_account_id + if payload.get("account_id") == account_id + and (user_id is None or payload.get("user_id") == user_id) ] - def delete(self, task_id: str, *, owner_account_id: str) -> None: + def delete(self, task_id: str, *, account_id: str, user_id: Optional[str] = None) -> None: payload = self._tasks.get(task_id) - if payload and payload.get("owner_account_id") == owner_account_id: + if ( + payload + and payload.get("account_id") == account_id + and (user_id is None or payload.get("user_id") == user_id) + ): del self._tasks[task_id] @@ -76,19 +91,25 @@ def update(self, task: Any) -> None: self._write_task(task) def get( - self, task_id: str, *, owner_account_id: Optional[str] = None + self, + task_id: str, + *, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> Optional[Dict[str, Any]]: - if not owner_account_id: + if not account_id or not user_id: return None - path = self._task_path(owner_account_id, task_id) + path = self._task_path(account_id, user_id, task_id) try: raw = self._agfs.read(path) except Exception: return None return json.loads(_decode_bytes(raw)) - def list(self, owner_account_id: str) -> List[Dict[str, Any]]: - directory = self._task_dir(owner_account_id) + def list(self, account_id: str, *, user_id: Optional[str] = None) -> List[Dict[str, Any]]: + if not user_id: + return [] + directory = self._task_dir(account_id, user_id) try: items = self._agfs.ls(directory) except Exception: @@ -105,22 +126,26 @@ def list(self, owner_account_id: str) -> List[Dict[str, Any]]: continue return tasks - def delete(self, task_id: str, *, owner_account_id: str) -> None: - self._agfs.rm(self._task_path(owner_account_id, task_id), force=True) + def delete(self, task_id: str, *, account_id: str, user_id: Optional[str] = None) -> None: + if not user_id: + return + self._agfs.rm(self._task_path(account_id, user_id, task_id), force=True) def _write_task(self, task: Any) -> None: - account_id = getattr(task, "owner_account_id", None) - if not account_id: - raise ValueError("PersistentTaskStore requires owner_account_id") - self._ensure_task_dir(account_id) + account_id = getattr(task, "account_id", None) + user_id = getattr(task, "user_id", None) + if not account_id or not user_id: + raise ValueError("PersistentTaskStore requires account_id and user_id") + self._ensure_task_dir(account_id, user_id) self._agfs.write( - self._task_path(account_id, task.task_id), + self._task_path(account_id, user_id, task.task_id), json.dumps(_task_to_payload(task), ensure_ascii=False).encode("utf-8"), ) - def _ensure_task_dir(self, account_id: str) -> None: + def _ensure_task_dir(self, account_id: str, user_id: str) -> None: self._mkdir_if_missing(self._account_dir(account_id)) - self._mkdir_if_missing(self._task_dir(account_id)) + self._mkdir_if_missing(self._task_root_dir(account_id)) + self._mkdir_if_missing(self._task_dir(account_id, user_id)) def _mkdir_if_missing(self, path: str) -> None: try: @@ -135,11 +160,14 @@ def _mkdir_if_missing(self, path: str) -> None: def _account_dir(self, account_id: str) -> str: return f"{self.ROOT_PREFIX}/{account_id}" - def _task_dir(self, account_id: str) -> str: + def _task_root_dir(self, account_id: str) -> str: return f"{self._account_dir(account_id)}/{self.RESERVED_DIRNAME}" - def _task_path(self, account_id: str, task_id: str) -> str: - return f"{self._task_dir(account_id)}/{task_id}.json" + def _task_dir(self, account_id: str, user_id: str) -> str: + return f"{self._task_root_dir(account_id)}/{user_id}" + + def _task_path(self, account_id: str, user_id: str, task_id: str) -> str: + return f"{self._task_dir(account_id, user_id)}/{task_id}.json" def _task_to_payload(task: Any) -> Dict[str, Any]: @@ -151,8 +179,8 @@ def _task_to_payload(task: Any) -> Dict[str, Any]: "created_at": task.created_at, "updated_at": task.updated_at, "resource_id": task.resource_id, - "owner_account_id": task.owner_account_id, - "owner_user_id": task.owner_user_id, + "account_id": task.account_id, + "user_id": task.user_id, "result": deepcopy(task.result), "error": task.error, } diff --git a/openviking/service/task_tracker.py b/openviking/service/task_tracker.py index 8c847b1d7..ff65363c5 100644 --- a/openviking/service/task_tracker.py +++ b/openviking/service/task_tracker.py @@ -49,8 +49,8 @@ class TaskRecord: created_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) resource_id: Optional[str] = None # e.g. session_id - owner_account_id: Optional[str] = None - owner_user_id: Optional[str] = None + account_id: Optional[str] = None + user_id: Optional[str] = None result: Optional[Dict[str, Any]] = None error: Optional[str] = None @@ -60,8 +60,8 @@ def to_dict(self) -> Dict[str, Any]: d["status"] = self.status.value d["created_at_iso"] = datetime.fromtimestamp(self.created_at, tz=timezone.utc).isoformat() d["updated_at_iso"] = datetime.fromtimestamp(self.updated_at, tz=timezone.utc).isoformat() - d.pop("owner_account_id", None) - d.pop("owner_user_id", None) + d.pop("account_id", None) + d.pop("user_id", None) return d @@ -178,8 +178,8 @@ def _evict_expired(self) -> None: to_delete.append(tid) for tid in to_delete: task = self._tasks[tid] - if isinstance(self._store, InMemoryTaskStore) and task.owner_account_id: - self._store.delete(tid, owner_account_id=task.owner_account_id) + if isinstance(self._store, InMemoryTaskStore) and task.account_id: + self._store.delete(tid, account_id=task.account_id, user_id=task.user_id) del self._tasks[tid] # FIFO eviction if still over limit @@ -188,8 +188,8 @@ def _evict_expired(self) -> None: excess = len(self._tasks) - self.MAX_TASKS for tid, _ in sorted_tasks[:excess]: task = self._tasks[tid] - if isinstance(self._store, InMemoryTaskStore) and task.owner_account_id: - self._store.delete(tid, owner_account_id=task.owner_account_id) + if isinstance(self._store, InMemoryTaskStore) and task.account_id: + self._store.delete(tid, account_id=task.account_id, user_id=task.user_id) del self._tasks[tid] if to_delete: @@ -198,21 +198,21 @@ def _evict_expired(self) -> None: @staticmethod def _matches_owner( task: TaskRecord, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> bool: """Return True when a task belongs to the requested owner filter.""" - if owner_account_id is not None and task.owner_account_id != owner_account_id: + if account_id is not None and task.account_id != account_id: return False - if owner_user_id is not None and task.owner_user_id != owner_user_id: + if user_id is not None and task.user_id != user_id: return False return True @staticmethod - def _validate_owner(owner_account_id: str, owner_user_id: str) -> None: + def _validate_owner(account_id: str, user_id: str) -> None: """Reject ownerless task creation for user-originated background work.""" - if not owner_account_id or not owner_user_id: - raise ValueError("Task ownership requires non-empty owner_account_id and owner_user_id") + if not account_id or not user_id: + raise ValueError("Task ownership requires non-empty account_id and user_id") # ── CRUD ── @@ -221,17 +221,17 @@ def create( task_type: str, resource_id: Optional[str] = None, *, - owner_account_id: str, - owner_user_id: str, + account_id: str, + user_id: str, ) -> TaskRecord: """Register a new pending task. Returns a snapshot copy.""" - self._validate_owner(owner_account_id, owner_user_id) + self._validate_owner(account_id, user_id) task = TaskRecord( task_id=str(uuid4()), task_type=task_type, resource_id=resource_id, - owner_account_id=owner_account_id, - owner_user_id=owner_user_id, + account_id=account_id, + user_id=user_id, ) with self._lock: self._tasks[task.task_id] = task @@ -249,21 +249,21 @@ def create_if_no_running( task_type: str, resource_id: str, *, - owner_account_id: str, - owner_user_id: str, + account_id: str, + user_id: str, ) -> Optional[TaskRecord]: """Atomically check for running tasks and create a new one if none exist. Returns TaskRecord on success, None if a running task already exists. This eliminates the race condition between has_running() and create(). """ - self._validate_owner(owner_account_id, owner_user_id) + self._validate_owner(account_id, user_id) with self._lock: # Check for existing running tasks has_active = any( t.task_type == task_type and t.resource_id == resource_id - and self._matches_owner(t, owner_account_id, owner_user_id) + and self._matches_owner(t, account_id, user_id) and t.status in (TaskStatus.PENDING, TaskStatus.RUNNING) for t in self._tasks.values() ) @@ -274,8 +274,8 @@ def create_if_no_running( task_id=str(uuid4()), task_type=task_type, resource_id=resource_id, - owner_account_id=owner_account_id, - owner_user_id=owner_user_id, + account_id=account_id, + user_id=user_id, ) self._tasks[task.task_id] = task self._store.create(task) @@ -287,10 +287,15 @@ def create_if_no_running( ) return self._copy(task) - def start(self, task_id: str, owner_account_id: Optional[str] = None) -> None: + def start( + self, + task_id: str, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> None: """Transition task to RUNNING.""" with self._lock: - task = self._load_for_update(task_id, owner_account_id) + task = self._load_for_update(task_id, account_id, user_id) if task: task.status = TaskStatus.RUNNING task.updated_at = time.time() @@ -301,11 +306,12 @@ def complete( self, task_id: str, result: Optional[Dict[str, Any]] = None, - owner_account_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> None: """Transition task to COMPLETED with optional result.""" with self._lock: - task = self._load_for_update(task_id, owner_account_id) + task = self._load_for_update(task_id, account_id, user_id) if task: task.status = TaskStatus.COMPLETED task.result = result @@ -314,10 +320,16 @@ def complete( self._store.update(task) logger.info("[TaskTracker] Task %s completed", task_id) - def fail(self, task_id: str, error: str, owner_account_id: Optional[str] = None) -> None: + def fail( + self, + task_id: str, + error: str, + account_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> None: """Transition task to FAILED with sanitized error.""" with self._lock: - task = self._load_for_update(task_id, owner_account_id) + task = self._load_for_update(task_id, account_id, user_id) if task: task.status = TaskStatus.FAILED task.error = _sanitize_error(error) @@ -329,17 +341,17 @@ def fail(self, task_id: str, error: str, owner_account_id: Optional[str] = None) def get( self, task_id: str, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> Optional[TaskRecord]: """Look up a single task. Returns a snapshot copy (None if not found).""" with self._lock: task = self._tasks.get(task_id) - if task is None and owner_account_id is not None: - task = self._load_from_store(task_id, owner_account_id) + if task is None and account_id is not None: + task = self._load_from_store(task_id, account_id, user_id) if task is not None: self._tasks[task.task_id] = task - if task is None or not self._matches_owner(task, owner_account_id, owner_user_id): + if task is None or not self._matches_owner(task, account_id, user_id): return None return self._copy(task) @@ -349,18 +361,18 @@ def list_tasks( status: Optional[str] = None, resource_id: Optional[str] = None, limit: int = 50, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> List[TaskRecord]: """List tasks with optional filters. Most-recent first. Returns snapshot copies.""" with self._lock: - if owner_account_id is not None: - for task in self._load_all_from_store(owner_account_id): + if account_id is not None: + for task in self._load_all_from_store(account_id, user_id): self._tasks[task.task_id] = task tasks = [ self._copy(t) for t in self._tasks.values() - if self._matches_owner(t, owner_account_id, owner_user_id) + if self._matches_owner(t, account_id, user_id) ] if task_type: tasks = [t for t in tasks if t.task_type == task_type] @@ -375,31 +387,34 @@ def has_running( self, task_type: str, resource_id: str, - owner_account_id: Optional[str] = None, - owner_user_id: Optional[str] = None, + account_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> bool: """Check if there is already a running task for the given type+resource.""" with self._lock: - if owner_account_id is not None: - for task in self._load_all_from_store(owner_account_id): + if account_id is not None: + for task in self._load_all_from_store(account_id, user_id): self._tasks[task.task_id] = task return any( t.task_type == task_type and t.resource_id == resource_id - and self._matches_owner(t, owner_account_id, owner_user_id) + and self._matches_owner(t, account_id, user_id) and t.status in (TaskStatus.PENDING, TaskStatus.RUNNING) for t in self._tasks.values() ) def _load_for_update( - self, task_id: str, owner_account_id: Optional[str] + self, + task_id: str, + account_id: Optional[str], + user_id: Optional[str], ) -> Optional[TaskRecord]: task = self._tasks.get(task_id) if task is not None: return task - if owner_account_id is None: + if account_id is None or user_id is None: return None - return self._load_from_store(task_id, owner_account_id) + return self._load_from_store(task_id, account_id, user_id) @staticmethod def _record_from_payload(payload: Dict[str, Any]) -> TaskRecord: @@ -407,15 +422,21 @@ def _record_from_payload(payload: Dict[str, Any]) -> TaskRecord: data["status"] = TaskStatus(data["status"]) return TaskRecord(**data) - def _load_from_store(self, task_id: str, owner_account_id: str) -> Optional[TaskRecord]: - payload = self._store.get(task_id, owner_account_id=owner_account_id) + def _load_from_store( + self, + task_id: str, + account_id: str, + user_id: Optional[str], + ) -> Optional[TaskRecord]: + payload = self._store.get(task_id, account_id=account_id, user_id=user_id) if payload is None: return None return self._record_from_payload(payload) - def _load_all_from_store(self, owner_account_id: str) -> List[TaskRecord]: + def _load_all_from_store(self, account_id: str, user_id: Optional[str]) -> List[TaskRecord]: return [ - self._record_from_payload(payload) for payload in self._store.list(owner_account_id) + self._record_from_payload(payload) + for payload in self._store.list(account_id, user_id=user_id) ] @staticmethod diff --git a/openviking/session/session.py b/openviking/session/session.py index 15b2e0ce1..f38cc38e6 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -705,8 +705,8 @@ async def commit_async(self, keep_recent_count: int = 0) -> Dict[str, Any]: task = tracker.create( "session_commit", resource_id=self.session_id, - owner_account_id=self.ctx.account_id, - owner_user_id=self.ctx.user.user_id, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, ) asyncio.create_task( @@ -773,11 +773,12 @@ async def _run_memory_extraction( task_id, f"Previous archive archive_{archive_index - 1:03d} failed; " "cannot continue session commit", - owner_account_id=self.ctx.account_id, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, ) return - tracker.start(task_id, owner_account_id=self.ctx.account_id) + tracker.start(task_id, account_id=self.ctx.account_id, user_id=self.ctx.user.user_id) request_wait_tracker.register_request(telemetry.telemetry_id) register_telemetry(telemetry) try: @@ -980,7 +981,8 @@ async def _noop_agent(): }, }, }, - owner_account_id=self.ctx.account_id, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, ) logger.info(f"Session {self.session_id} memory extraction completed") except Exception as e: @@ -991,7 +993,9 @@ async def _noop_agent(): stage="memory_extraction", error=str(e), ) - tracker.fail(task_id, str(e), owner_account_id=self.ctx.account_id) + tracker.fail( + task_id, str(e), account_id=self.ctx.account_id, user_id=self.ctx.user.user_id + ) logger.exception(f"Memory extraction failed for session {self.session_id}") async def _write_done_file( diff --git a/tests/server/test_auth.py b/tests/server/test_auth.py index c14337501..b2692041d 100644 --- a/tests/server/test_auth.py +++ b/tests/server/test_auth.py @@ -323,14 +323,14 @@ async def test_task_endpoints_are_user_scoped(): alice_task = tracker.create( "session_commit", resource_id="alice-session", - owner_account_id=account_id, - owner_user_id="alice", + account_id=account_id, + user_id="alice", ) bob_task = tracker.create( "session_commit", resource_id="bob-session", - owner_account_id=account_id, - owner_user_id="bob", + account_id=account_id, + user_id="bob", ) alice_app = _build_task_http_test_app( diff --git a/tests/test_session_task_tracking.py b/tests/test_session_task_tracking.py index c8c2d15b3..617580ac7 100644 --- a/tests/test_session_task_tracking.py +++ b/tests/test_session_task_tracking.py @@ -69,13 +69,13 @@ async def mock_commit(_sid, _ctx): task = tracker.create( "session_commit", resource_id=_sid, - owner_account_id=_ctx.account_id, - owner_user_id=_ctx.user.user_id, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, ) archive_uri = f"viking://session/test/{_sid}/history/archive_001" async def _background(): - tracker.start(task.task_id) + tracker.start(task.task_id, account_id=_ctx.account_id, user_id=_ctx.user.user_id) try: if started: started.set() @@ -96,9 +96,19 @@ async def _background(): } if result_overrides: final_result.update(result_overrides) - tracker.complete(task.task_id, final_result) + tracker.complete( + task.task_id, + final_result, + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, + ) except Exception as e: - tracker.fail(task.task_id, str(e)) + tracker.fail( + task.task_id, + str(e), + account_id=_ctx.account_id, + user_id=_ctx.user.user_id, + ) asyncio.create_task(_background()) @@ -336,11 +346,20 @@ async def fake_add_resource(**kwargs): task = tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.complete( + task.task_id, + {"root_uri": root_uri}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) - tracker.start(task.task_id) - tracker.complete(task.task_id, {"root_uri": root_uri}) return {"status": "success", "root_uri": root_uri, "task_id": task.task_id} service.resources.add_resource = fake_add_resource @@ -389,14 +408,23 @@ async def fake_add_resource(**kwargs): task = tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) async def _background(): - tracker.start(task.task_id) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) await asyncio.sleep(0.05) - tracker.complete(task.task_id, {"root_uri": root_uri}) + tracker.complete( + task.task_id, + {"root_uri": root_uri}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) asyncio.create_task(_background()) return {"status": "success", "root_uri": root_uri, "task_id": task.task_id} @@ -434,11 +462,20 @@ async def fake_add_resource(**kwargs): task = tracker.create( "add_resource", resource_id=root_uri, - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.complete( + task.task_id, + {"root_uri": root_uri}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) - tracker.start(task.task_id) - tracker.complete(task.task_id, {"root_uri": root_uri}) return {"status": "success", "root_uri": root_uri, "task_id": task.task_id} service.resources.add_resource = fake_add_resource @@ -469,11 +506,20 @@ async def fake_add_skill(**kwargs): tracker = get_task_tracker() task = tracker.create( "add_skill", - owner_account_id=kwargs["ctx"].account_id, - owner_user_id=kwargs["ctx"].user.user_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.start( + task.task_id, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, + ) + tracker.complete( + task.task_id, + {}, + account_id=kwargs["ctx"].account_id, + user_id=kwargs["ctx"].user.user_id, ) - tracker.start(task.task_id) - tracker.complete(task.task_id, {}) return {"status": "success", "task_id": task.task_id} service.resources.add_skill = fake_add_skill diff --git a/tests/test_task_tracker.py b/tests/test_task_tracker.py index 76b0c9a00..a6b4d7d74 100644 --- a/tests/test_task_tracker.py +++ b/tests/test_task_tracker.py @@ -37,8 +37,8 @@ def tracker() -> TaskTracker: def _owner_kwargs(account_id: str = "acme", user_id: str = "alice"): return { - "owner_account_id": account_id, - "owner_user_id": user_id, + "account_id": account_id, + "user_id": user_id, } @@ -188,15 +188,15 @@ def test_get_hides_task_from_other_owner(tracker: TaskTracker): task = tracker.create( "session_commit", resource_id="s1", - owner_account_id="acme", - owner_user_id="alice", + account_id="acme", + user_id="alice", ) assert ( tracker.get( task.task_id, - owner_account_id="acme", - owner_user_id="bob", + account_id="acme", + user_id="bob", ) is None ) @@ -206,17 +206,17 @@ def test_list_tasks_filters_by_owner(tracker: TaskTracker): tracker.create( "session_commit", resource_id="alice-task", - owner_account_id="acme", - owner_user_id="alice", + account_id="acme", + user_id="alice", ) tracker.create( "session_commit", resource_id="bob-task", - owner_account_id="acme", - owner_user_id="bob", + account_id="acme", + user_id="bob", ) - tasks = tracker.list_tasks(owner_account_id="acme", owner_user_id="alice") + tasks = tracker.list_tasks(account_id="acme", user_id="alice") assert len(tasks) == 1 assert tasks[0].resource_id == "alice-task" @@ -269,14 +269,14 @@ def test_create_if_no_running_isolated_by_owner(tracker: TaskTracker): alice_task = tracker.create_if_no_running( "reindex", "viking://resources/demo", - owner_account_id="acme", - owner_user_id="alice", + account_id="acme", + user_id="alice", ) bob_task = tracker.create_if_no_running( "reindex", "viking://resources/demo", - owner_account_id="acme", - owner_user_id="bob", + account_id="acme", + user_id="bob", ) assert alice_task is not None @@ -303,8 +303,8 @@ def test_to_dict(tracker: TaskTracker): assert isinstance(d["created_at_iso"], str) assert "T" in d["created_at_iso"] assert isinstance(d["updated_at_iso"], str) - assert "owner_account_id" not in d - assert "owner_user_id" not in d + assert "account_id" not in d + assert "user_id" not in d # ── Sanitization ── @@ -392,10 +392,10 @@ def test_persistent_store_cross_tracker_visibility(): tracker2 = TaskTracker(store=store) task = tracker1.create("session_commit", resource_id="sess-123", **_owner_kwargs()) - tracker1.start(task.task_id, owner_account_id="acme") - tracker1.complete(task.task_id, {"ok": True}, owner_account_id="acme") + tracker1.start(task.task_id, account_id="acme", user_id="alice") + tracker1.complete(task.task_id, {"ok": True}, account_id="acme", user_id="alice") - loaded = tracker2.get(task.task_id, owner_account_id="acme", owner_user_id="alice") + loaded = tracker2.get(task.task_id, account_id="acme", user_id="alice") assert loaded is not None assert loaded.status == TaskStatus.COMPLETED @@ -409,13 +409,13 @@ def test_persistent_store_writes_task_record_json(): task = tracker.create("add_resource", resource_id="viking://resources/demo", **_owner_kwargs()) - raw = agfs.files[f"/local/acme/tasks/{task.task_id}.json"] + raw = agfs.files[f"/local/acme/tasks/alice/{task.task_id}.json"] payload = json.loads(raw.decode("utf-8")) assert payload["task_id"] == task.task_id assert payload["task_type"] == "add_resource" - assert payload["owner_account_id"] == "acme" - assert payload["owner_user_id"] == "alice" + assert payload["account_id"] == "acme" + assert payload["user_id"] == "alice" assert "schema_version" not in payload @@ -429,10 +429,10 @@ def test_persistent_store_survives_tracker_reset(): agfs = _FakeAgfs() tracker1 = TaskTracker(store=PersistentTaskStore(agfs)) task = tracker1.create("session_commit", resource_id="sess-123", **_owner_kwargs()) - tracker1.start(task.task_id, owner_account_id="acme") + tracker1.start(task.task_id, account_id="acme", user_id="alice") tracker2 = TaskTracker(store=PersistentTaskStore(agfs)) - loaded = tracker2.get(task.task_id, owner_account_id="acme", owner_user_id="alice") + loaded = tracker2.get(task.task_id, account_id="acme", user_id="alice") assert loaded is not None assert loaded.status == TaskStatus.RUNNING @@ -446,8 +446,8 @@ def test_persistent_store_ignores_existing_task_dirs(): second = tracker.create("session_commit", resource_id="sess-2", **_owner_kwargs()) assert first.task_id != second.task_id - assert agfs.files[f"/local/acme/tasks/{first.task_id}.json"] - assert agfs.files[f"/local/acme/tasks/{second.task_id}.json"] + assert agfs.files[f"/local/acme/tasks/alice/{first.task_id}.json"] + assert agfs.files[f"/local/acme/tasks/alice/{second.task_id}.json"] def test_create_requires_owner(tracker: TaskTracker): @@ -465,8 +465,8 @@ def test_create_rejects_blank_owner_values(tracker: TaskTracker): tracker.create( "session_commit", resource_id="sess-123", - owner_account_id="", - owner_user_id="alice", + account_id="", + user_id="alice", )