diff --git a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py index a51af64550..60dc85f760 100644 --- a/openviking/storage/queuefs/embedding_tracker.py +++ b/openviking/storage/queuefs/embedding_tracker.py @@ -3,6 +3,9 @@ """Embedding Task Tracker for tracking embedding task completion status.""" import asyncio +import inspect +import threading +from dataclasses import dataclass from typing import Any, Callable, Dict, Optional from openviking_cli.utils.logger import get_logger @@ -10,12 +13,27 @@ logger = get_logger(__name__) +@dataclass +class _EmbeddingTaskRecord: + """Coordinator state for a single semantic message.""" + + remaining: int + total: int + on_complete: Optional[Callable[[], Any]] + metadata: Dict[str, Any] + owner_loop: Optional[asyncio.AbstractEventLoop] + + class EmbeddingTaskTracker: """Track embedding task completion status for each SemanticMsg. - This tracker maintains a global registry of embedding tasks associated - with each SemanticMsg. When all embedding tasks for a SemanticMsg are - completed, it triggers the registered callback and removes the entry. + This tracker maintains a process-global registry of embedding tasks associated + with each SemanticMsg. Because semantic and embedding queues run on separate + worker threads with distinct event loops, its internal state must be guarded + by thread-safe primitives rather than loop-bound asyncio locks. + + When all embedding tasks for a SemanticMsg are completed, it triggers the + registered callback and removes the entry. """ _instance: Optional["EmbeddingTaskTracker"] = None @@ -29,10 +47,72 @@ def __new__(cls) -> "EmbeddingTaskTracker": def __init__(self): if self._initialized: return - self._lock: asyncio.Lock = asyncio.Lock() - self._tasks: Dict[str, Dict[str, Any]] = {} + self._lock = threading.Lock() + self._tasks: Dict[str, _EmbeddingTaskRecord] = {} self._initialized = True + @staticmethod + async def _await_callback_result(result: Any) -> None: + """Await callback results when they are async.""" + if inspect.isawaitable(result): + await result + + async def _execute_callback(self, on_complete: Callable[[], Any]) -> None: + """Invoke a completion callback and await async results.""" + await self._await_callback_result(on_complete()) + + async def _run_on_complete( + self, + semantic_msg_id: str, + record: _EmbeddingTaskRecord, + ) -> None: + """Execute the completion callback on the loop that registered it.""" + on_complete = record.on_complete + owner_loop = record.owner_loop + if on_complete is None: + return + + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + current_loop = None + + owner_loop_running = bool(owner_loop and owner_loop.is_running()) + owner_loop_available = bool( + owner_loop and not owner_loop.is_closed() and owner_loop_running + ) + + try: + if owner_loop and owner_loop is not current_loop: + if not owner_loop_available: + logger.warning( + "Owner loop unavailable before completion callback for %s; " + "running callback in current loop", + semantic_msg_id, + ) + else: + try: + fut = asyncio.run_coroutine_threadsafe( + self._execute_callback(on_complete), + owner_loop, + ) + except RuntimeError: + logger.warning( + "Owner loop stopped before completion callback for %s; " + "running callback in current loop", + semantic_msg_id, + ) + else: + await asyncio.wrap_future(fut) + return + + await self._execute_callback(on_complete) + except Exception as e: + logger.error( + f"Error in completion callback for {semantic_msg_id}: {e}", + exc_info=True, + ) + @classmethod def get_instance(cls) -> "EmbeddingTaskTracker": """Get the singleton instance of EmbeddingTaskTracker.""" @@ -55,35 +135,38 @@ async def register( on_complete: Optional callback when all tasks complete metadata: Optional metadata to store with the task """ - async with self._lock: - self._tasks[semantic_msg_id] = { - "remaining": total_count, - "total": total_count, - "on_complete": on_complete, - "metadata": metadata or {}, - } + owner_loop = asyncio.get_running_loop() + record_to_finalize: Optional[_EmbeddingTaskRecord] = None + + with self._lock: + existing = self._tasks.get(semantic_msg_id) + if existing is not None: + logger.warning( + "Overwriting existing embedding tracker record for SemanticMsg %s", + semantic_msg_id, + ) + + self._tasks[semantic_msg_id] = _EmbeddingTaskRecord( + remaining=total_count, + total=total_count, + on_complete=on_complete, + metadata=metadata or {}, + owner_loop=owner_loop, + ) logger.info( f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: " f"{total_count} tasks" ) - if total_count <= 0 and on_complete: - del self._tasks[semantic_msg_id] + if total_count <= 0: + record_to_finalize = self._tasks.pop(semantic_msg_id) logger.info( f"No embedding tasks for SemanticMsg {semantic_msg_id}, " - f"triggering on_complete immediately" + f"clearing tracker entry immediately" ) - if total_count <= 0 and on_complete: - try: - result = on_complete() - if asyncio.iscoroutine(result): - await result - except Exception as e: - logger.error( - f"Error in completion callback for {semantic_msg_id}: {e}", - exc_info=True, - ) + if record_to_finalize is not None: + await self._run_on_complete(semantic_msg_id, record_to_finalize) async def decrement(self, semantic_msg_id: str) -> Optional[int]: """Decrement the remaining task count for a SemanticMsg. @@ -98,32 +181,22 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]: Returns: The remaining count after decrement, or None if not found """ - on_complete = None + record_to_finalize: Optional[_EmbeddingTaskRecord] = None - async with self._lock: - if semantic_msg_id not in self._tasks: + with self._lock: + record = self._tasks.get(semantic_msg_id) + if record is None: return None - task_info = self._tasks[semantic_msg_id] - task_info["remaining"] -= 1 - remaining = task_info["remaining"] + record.remaining -= 1 + remaining = record.remaining if remaining <= 0: - on_complete = task_info.get("on_complete") - - del self._tasks[semantic_msg_id] + record_to_finalize = self._tasks.pop(semantic_msg_id) logger.info( - f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}" + f"All embedding tasks({record.total}) completed for SemanticMsg {semantic_msg_id}" ) - if on_complete: - try: - result = on_complete() - if asyncio.iscoroutine(result): - await result - except Exception as e: - logger.error( - f"Error in completion callback for {semantic_msg_id}: {e}", - exc_info=True, - ) + if record_to_finalize is not None: + await self._run_on_complete(semantic_msg_id, record_to_finalize) return remaining diff --git a/tests/storage/test_embedding_tracker.py b/tests/storage/test_embedding_tracker.py new file mode 100644 index 0000000000..1e3db2a9e0 --- /dev/null +++ b/tests/storage/test_embedding_tracker.py @@ -0,0 +1,187 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import concurrent.futures +import threading +import time + +import pytest + +from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker + + +class _LoopThread: + def __init__(self, close_delay: float = 0) -> None: + self.loop = asyncio.new_event_loop() + self._ready = threading.Event() + self._close_delay = close_delay + self.thread = threading.Thread(target=self._run, daemon=True) + self.thread.start() + self._ready.wait(timeout=2) + + def _run(self) -> None: + asyncio.set_event_loop(self.loop) + self._ready.set() + self.loop.run_forever() + if self._close_delay: + time.sleep(self._close_delay) + pending = asyncio.all_tasks(self.loop) + for task in pending: + task.cancel() + if pending: + self.loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + self.loop.close() + + def submit(self, coro): + return asyncio.run_coroutine_threadsafe(coro, self.loop) + + def stop(self) -> None: + if self.loop.is_closed(): + return + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join(timeout=3) + + def stop_without_join(self) -> None: + if self.loop.is_closed(): + return + self.loop.call_soon_threadsafe(self.loop.stop) + + def join(self) -> None: + self.thread.join(timeout=3) + + +@pytest.fixture(autouse=True) +def _reset_tracker_singleton(): + EmbeddingTaskTracker._instance = None + EmbeddingTaskTracker._initialized = False + yield + EmbeddingTaskTracker._instance = None + EmbeddingTaskTracker._initialized = False + + +def test_tracker_runs_completion_callback_on_register_loop(): + tracker = EmbeddingTaskTracker.get_instance() + owner = _LoopThread() + worker = _LoopThread() + callback_info = concurrent.futures.Future() + + async def on_complete(): + callback_info.set_result((threading.get_ident(), asyncio.get_running_loop())) + await asyncio.sleep(0) + + async def register(): + await tracker.register("semantic-msg", 1, on_complete=on_complete) + + async def decrement(): + return await tracker.decrement("semantic-msg") + + try: + owner.submit(register()).result(timeout=2) + assert not callback_info.done() + + remaining = worker.submit(decrement()).result(timeout=2) + callback_thread_id, callback_loop = callback_info.result(timeout=2) + finally: + owner.stop() + worker.stop() + + assert remaining == 0 + assert callback_thread_id == owner.thread.ident + assert callback_loop is owner.loop + + +def test_tracker_falls_back_to_current_loop_when_owner_loop_is_closed(): + tracker = EmbeddingTaskTracker.get_instance() + owner = _LoopThread() + worker = _LoopThread() + callback_info = concurrent.futures.Future() + + async def on_complete(): + callback_info.set_result((threading.get_ident(), asyncio.get_running_loop())) + + async def register(): + await tracker.register("semantic-msg", 1, on_complete=on_complete) + + async def decrement(): + return await tracker.decrement("semantic-msg") + + try: + owner.submit(register()).result(timeout=2) + owner.stop() + + remaining = worker.submit(decrement()).result(timeout=2) + callback_thread_id, callback_loop = callback_info.result(timeout=2) + finally: + worker.stop() + + assert remaining == 0 + assert callback_thread_id == worker.thread.ident + assert callback_loop is worker.loop + + +def test_tracker_falls_back_to_current_loop_when_owner_loop_is_stopped(): + tracker = EmbeddingTaskTracker.get_instance() + owner = _LoopThread(close_delay=1) + worker = _LoopThread() + callback_info = concurrent.futures.Future() + + async def on_complete(): + callback_info.set_result((threading.get_ident(), asyncio.get_running_loop())) + + async def register(): + await tracker.register("semantic-msg", 1, on_complete=on_complete) + + async def decrement(): + return await tracker.decrement("semantic-msg") + + try: + owner.submit(register()).result(timeout=2) + owner.stop_without_join() + time.sleep(0.1) + + remaining = worker.submit(decrement()).result(timeout=2) + callback_thread_id, callback_loop = callback_info.result(timeout=2) + finally: + worker.stop() + owner.join() + + assert remaining == 0 + assert callback_thread_id == worker.thread.ident + assert callback_loop is worker.loop + + +@pytest.mark.asyncio +async def test_tracker_runs_zero_task_callback_immediately(): + tracker = EmbeddingTaskTracker.get_instance() + callback_loop = None + + async def on_complete(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + + await tracker.register("semantic-msg", 0, on_complete=on_complete) + + assert callback_loop is asyncio.get_running_loop() + + +@pytest.mark.asyncio +async def test_tracker_supports_sync_callback_and_missing_task(): + tracker = EmbeddingTaskTracker.get_instance() + callback_calls = [] + + await tracker.register("semantic-msg", 1, on_complete=lambda: callback_calls.append("done")) + remaining = await tracker.decrement("semantic-msg") + + assert remaining == 0 + assert callback_calls == ["done"] + assert await tracker.decrement("missing-semantic-msg") is None + + +@pytest.mark.asyncio +async def test_tracker_clears_zero_task_entry_without_callback(): + tracker = EmbeddingTaskTracker.get_instance() + + await tracker.register("semantic-msg", 0, on_complete=None) + + assert await tracker.decrement("semantic-msg") is None diff --git a/tests/storage/test_semantic_dag_skip_files.py b/tests/storage/test_semantic_dag_skip_files.py index 3eaeaa2f82..9a9be4fb05 100644 --- a/tests/storage/test_semantic_dag_skip_files.py +++ b/tests/storage/test_semantic_dag_skip_files.py @@ -69,7 +69,14 @@ async def _vectorize_directory_simple(self, uri, context_type, abstract, overvie await self._vectorize_directory(uri, context_type, abstract, overview, ctx=ctx) async def _vectorize_single_file( - self, parent_uri, context_type, file_path, summary_dict, ctx=None, semantic_msg_id=None + self, + parent_uri, + context_type, + file_path, + summary_dict, + ctx=None, + semantic_msg_id=None, + use_summary=False, ): self.vectorized_files.append(file_path) diff --git a/tests/storage/test_semantic_dag_stats.py b/tests/storage/test_semantic_dag_stats.py index 94f9441f91..298d142f8b 100644 --- a/tests/storage/test_semantic_dag_stats.py +++ b/tests/storage/test_semantic_dag_stats.py @@ -49,7 +49,14 @@ async def _vectorize_directory( self.vectorized_dirs.append(uri) async def _vectorize_single_file( - self, parent_uri, context_type, file_path, summary_dict, ctx=None, semantic_msg_id=None + self, + parent_uri, + context_type, + file_path, + summary_dict, + ctx=None, + semantic_msg_id=None, + use_summary=False, ): self.vectorized_files.append(file_path)