Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 118 additions & 45 deletions openviking/storage/queuefs/embedding_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,37 @@
"""Embedding Task Tracker for tracking embedding task completion status."""

import asyncio
import inspect
import threading
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

from openviking_cli.utils.logger import get_logger

logger = get_logger(__name__)


@dataclass
class _EmbeddingTaskRecord:
"""Coordinator state for a single semantic message."""

remaining: int
total: int
on_complete: Optional[Callable[[], Any]]
metadata: Dict[str, Any]
owner_loop: Optional[asyncio.AbstractEventLoop]


class EmbeddingTaskTracker:
"""Track embedding task completion status for each SemanticMsg.

This tracker maintains a global registry of embedding tasks associated
with each SemanticMsg. When all embedding tasks for a SemanticMsg are
completed, it triggers the registered callback and removes the entry.
This tracker maintains a process-global registry of embedding tasks associated
with each SemanticMsg. Because semantic and embedding queues run on separate
worker threads with distinct event loops, its internal state must be guarded
by thread-safe primitives rather than loop-bound asyncio locks.

When all embedding tasks for a SemanticMsg are completed, it triggers the
registered callback and removes the entry.
"""

_instance: Optional["EmbeddingTaskTracker"] = None
Expand All @@ -29,10 +47,72 @@ def __new__(cls) -> "EmbeddingTaskTracker":
def __init__(self):
if self._initialized:
return
self._lock: asyncio.Lock = asyncio.Lock()
self._tasks: Dict[str, Dict[str, Any]] = {}
self._lock = threading.Lock()
self._tasks: Dict[str, _EmbeddingTaskRecord] = {}
self._initialized = True

@staticmethod
async def _await_callback_result(result: Any) -> None:
"""Await callback results when they are async."""
if inspect.isawaitable(result):
await result

async def _execute_callback(self, on_complete: Callable[[], Any]) -> None:
"""Invoke a completion callback and await async results."""
await self._await_callback_result(on_complete())

async def _run_on_complete(
    self,
    semantic_msg_id: str,
    record: _EmbeddingTaskRecord,
) -> None:
    """Run a record's completion callback, preferring the loop that registered it.

    Args:
        semantic_msg_id: Identifier used only in log messages.
        record: The tracker record whose ``on_complete`` should be invoked.

    Behaviour:
        * No callback registered -> returns immediately.
        * Owner loop is the current loop (or no owner loop) -> the callback
          runs inline on the current loop.
        * Owner loop is a different, running loop -> the callback is scheduled
          there and this coroutine waits for it to finish.
        * Owner loop is closed/not running, or scheduling fails -> a warning is
          logged and the callback runs on the current loop instead.

    Exceptions raised by the callback are logged and swallowed, never propagated.
    """
    on_complete = record.on_complete
    owner_loop = record.owner_loop
    if on_complete is None:
        return

    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Only reachable if this coroutine is driven without a running loop.
        current_loop = None

    owner_loop_available = (
        owner_loop is not None
        and not owner_loop.is_closed()
        and owner_loop.is_running()
    )

    try:
        if owner_loop is not None and owner_loop is not current_loop:
            if not owner_loop_available:
                logger.warning(
                    "Owner loop unavailable before completion callback for %s; "
                    "running callback in current loop",
                    semantic_msg_id,
                )
            else:
                # Keep a handle on the coroutine so it can be closed if the
                # hand-off fails; otherwise it would be garbage-collected
                # un-awaited and emit a "never awaited" RuntimeWarning.
                pending = self._execute_callback(on_complete)
                try:
                    fut = asyncio.run_coroutine_threadsafe(pending, owner_loop)
                except RuntimeError:
                    # The loop stopped between the availability check and
                    # scheduling; fall through to the inline path below.
                    pending.close()
                    logger.warning(
                        "Owner loop stopped before completion callback for %s; "
                        "running callback in current loop",
                        semantic_msg_id,
                    )
                else:
                    # NOTE(review): if the owner loop stops after scheduling but
                    # before running the callback, this await may never finish
                    # -- confirm shutdown ordering guarantees.
                    await asyncio.wrap_future(fut)
                    return

        await self._execute_callback(on_complete)
    except Exception as e:
        logger.error(
            f"Error in completion callback for {semantic_msg_id}: {e}",
            exc_info=True,
        )

@classmethod
def get_instance(cls) -> "EmbeddingTaskTracker":
"""Get the singleton instance of EmbeddingTaskTracker."""
Expand All @@ -55,35 +135,38 @@ async def register(
on_complete: Optional callback when all tasks complete
metadata: Optional metadata to store with the task
"""
async with self._lock:
self._tasks[semantic_msg_id] = {
"remaining": total_count,
"total": total_count,
"on_complete": on_complete,
"metadata": metadata or {},
}
owner_loop = asyncio.get_running_loop()
record_to_finalize: Optional[_EmbeddingTaskRecord] = None

with self._lock:
existing = self._tasks.get(semantic_msg_id)
if existing is not None:
logger.warning(
"Overwriting existing embedding tracker record for SemanticMsg %s",
semantic_msg_id,
)

self._tasks[semantic_msg_id] = _EmbeddingTaskRecord(
remaining=total_count,
total=total_count,
on_complete=on_complete,
metadata=metadata or {},
owner_loop=owner_loop,
)
logger.info(
f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: "
f"{total_count} tasks"
)

if total_count <= 0 and on_complete:
del self._tasks[semantic_msg_id]
if total_count <= 0:
record_to_finalize = self._tasks.pop(semantic_msg_id)
logger.info(
f"No embedding tasks for SemanticMsg {semantic_msg_id}, "
f"triggering on_complete immediately"
f"clearing tracker entry immediately"
)

if total_count <= 0 and on_complete:
try:
result = on_complete()
if asyncio.iscoroutine(result):
await result
except Exception as e:
logger.error(
f"Error in completion callback for {semantic_msg_id}: {e}",
exc_info=True,
)
if record_to_finalize is not None:
await self._run_on_complete(semantic_msg_id, record_to_finalize)

async def decrement(self, semantic_msg_id: str) -> Optional[int]:
"""Decrement the remaining task count for a SemanticMsg.
Expand All @@ -98,32 +181,22 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]:
Returns:
The remaining count after decrement, or None if not found
"""
on_complete = None
record_to_finalize: Optional[_EmbeddingTaskRecord] = None

async with self._lock:
if semantic_msg_id not in self._tasks:
with self._lock:
record = self._tasks.get(semantic_msg_id)
if record is None:
return None

task_info = self._tasks[semantic_msg_id]
task_info["remaining"] -= 1
remaining = task_info["remaining"]
record.remaining -= 1
remaining = record.remaining

if remaining <= 0:
on_complete = task_info.get("on_complete")

del self._tasks[semantic_msg_id]
record_to_finalize = self._tasks.pop(semantic_msg_id)
logger.info(
f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}"
f"All embedding tasks({record.total}) completed for SemanticMsg {semantic_msg_id}"
)

if on_complete:
try:
result = on_complete()
if asyncio.iscoroutine(result):
await result
except Exception as e:
logger.error(
f"Error in completion callback for {semantic_msg_id}: {e}",
exc_info=True,
)
if record_to_finalize is not None:
await self._run_on_complete(semantic_msg_id, record_to_finalize)
return remaining
Loading
Loading