Skip to content

Commit 7eb50cb

Browse files
authored
fix(queuefs): harden embedding tracker across worker loops (#1024)
* fix(queuefs): harden embedding tracker across worker loops Replace the global asyncio lock in EmbeddingTaskTracker with thread-safe coordination so semantic and embedding workers can share completion state without cross-event-loop failures. Add regression coverage for cross-loop completion dispatch and align semantic DAG test doubles with the current vectorize API. * fix(queuefs): avoid handing callbacks to stopped owner loops Treat stopped semantic owner loops as unavailable before dispatching completion callbacks from embedding workers. Add regression coverage for the stop-before-close shutdown window so the final decrement path falls back to the current loop instead of hanging.
1 parent af163d3 commit 7eb50cb

4 files changed

Lines changed: 321 additions & 47 deletions

File tree

openviking/storage/queuefs/embedding_tracker.py

Lines changed: 118 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,37 @@
33
"""Embedding Task Tracker for tracking embedding task completion status."""
44

55
import asyncio
6+
import inspect
7+
import threading
8+
from dataclasses import dataclass
69
from typing import Any, Callable, Dict, Optional
710

811
from openviking_cli.utils.logger import get_logger
912

1013
logger = get_logger(__name__)
1114

1215

16+
@dataclass
class _EmbeddingTaskRecord:
    """Per-SemanticMsg bookkeeping entry held by the embedding tracker.

    One record is stored for each registered semantic message id and is
    discarded once its outstanding embedding task count reaches zero.
    """

    # Number of embedding tasks that have not reported completion yet.
    remaining: int
    # Number of embedding tasks the record was registered with.
    total: int
    # Optional zero-argument callback (sync or async) fired on completion.
    on_complete: Optional[Callable[[], Any]]
    # Free-form data stored alongside the record by the registrant.
    metadata: Dict[str, Any]
    # Event loop that was running when the record was registered, if any.
    owner_loop: Optional[asyncio.AbstractEventLoop]
25+
26+
1327
class EmbeddingTaskTracker:
1428
"""Track embedding task completion status for each SemanticMsg.
1529
16-
This tracker maintains a global registry of embedding tasks associated
17-
with each SemanticMsg. When all embedding tasks for a SemanticMsg are
18-
completed, it triggers the registered callback and removes the entry.
30+
This tracker maintains a process-global registry of embedding tasks associated
31+
with each SemanticMsg. Because semantic and embedding queues run on separate
32+
worker threads with distinct event loops, its internal state must be guarded
33+
by thread-safe primitives rather than loop-bound asyncio locks.
34+
35+
When all embedding tasks for a SemanticMsg are completed, it triggers the
36+
registered callback and removes the entry.
1937
"""
2038

2139
_instance: Optional["EmbeddingTaskTracker"] = None
@@ -29,10 +47,72 @@ def __new__(cls) -> "EmbeddingTaskTracker":
2947
def __init__(self):
3048
if self._initialized:
3149
return
32-
self._lock: asyncio.Lock = asyncio.Lock()
33-
self._tasks: Dict[str, Dict[str, Any]] = {}
50+
self._lock = threading.Lock()
51+
self._tasks: Dict[str, _EmbeddingTaskRecord] = {}
3452
self._initialized = True
3553

54+
@staticmethod
async def _await_callback_result(result: Any) -> None:
    """Wait for ``result`` to finish when it is awaitable.

    Args:
        result: Whatever a completion callback returned. Non-awaitable
            values are ignored.
    """
    # Plain (synchronous) return values need no further handling.
    if not inspect.isawaitable(result):
        return
    await result
59+
60+
async def _execute_callback(self, on_complete: Callable[[], Any]) -> None:
    """Call ``on_complete`` and wait for its result if that result is awaitable.

    Args:
        on_complete: Zero-argument completion callback; may be sync or async.
    """
    # Invoke first, then let the helper decide whether the outcome needs awaiting.
    outcome = on_complete()
    await self._await_callback_result(outcome)
63+
64+
async def _run_on_complete(
    self,
    semantic_msg_id: str,
    record: "_EmbeddingTaskRecord",
) -> None:
    """Run a record's completion callback, preferring the loop that registered it.

    If the record's owner loop differs from the loop executing this coroutine
    and is still open and running, the callback is scheduled on the owner loop
    and awaited from here. Otherwise (no owner loop, same loop, owner loop
    closed/stopped, or scheduling fails) the callback runs in the current loop.

    Args:
        semantic_msg_id: Identifier used only in log messages.
        record: Tracker record carrying ``on_complete`` and ``owner_loop``.

    Exceptions raised while running the callback are logged and swallowed, so
    this coroutine does not propagate callback failures to its caller.
    """
    on_complete = record.on_complete
    owner_loop = record.owner_loop
    if on_complete is None:
        return

    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        current_loop = None

    # A loop that is stopped but not yet closed would accept the scheduled
    # coroutine and never run it, so it must be both open and running.
    owner_loop_available = (
        owner_loop is not None
        and not owner_loop.is_closed()
        and owner_loop.is_running()
    )

    try:
        if owner_loop is not None and owner_loop is not current_loop:
            if not owner_loop_available:
                logger.warning(
                    "Owner loop unavailable before completion callback for %s; "
                    "running callback in current loop",
                    semantic_msg_id,
                )
            else:
                # Keep a handle on the coroutine so it can be closed if the
                # hand-off fails; otherwise Python emits a "coroutine ... was
                # never awaited" RuntimeWarning for the abandoned object.
                coro = self._execute_callback(on_complete)
                try:
                    fut = asyncio.run_coroutine_threadsafe(coro, owner_loop)
                except RuntimeError:
                    coro.close()
                    logger.warning(
                        "Owner loop stopped before completion callback for %s; "
                        "running callback in current loop",
                        semantic_msg_id,
                    )
                else:
                    # Bridge the concurrent future into the current loop.
                    await asyncio.wrap_future(fut)
                    return

        # Fallback: execute the callback on whichever loop is running us.
        await self._execute_callback(on_complete)
    except Exception as e:
        logger.error(
            f"Error in completion callback for {semantic_msg_id}: {e}",
            exc_info=True,
        )
115+
36116
@classmethod
37117
def get_instance(cls) -> "EmbeddingTaskTracker":
38118
"""Get the singleton instance of EmbeddingTaskTracker."""
@@ -55,35 +135,38 @@ async def register(
55135
on_complete: Optional callback when all tasks complete
56136
metadata: Optional metadata to store with the task
57137
"""
58-
async with self._lock:
59-
self._tasks[semantic_msg_id] = {
60-
"remaining": total_count,
61-
"total": total_count,
62-
"on_complete": on_complete,
63-
"metadata": metadata or {},
64-
}
138+
owner_loop = asyncio.get_running_loop()
139+
record_to_finalize: Optional[_EmbeddingTaskRecord] = None
140+
141+
with self._lock:
142+
existing = self._tasks.get(semantic_msg_id)
143+
if existing is not None:
144+
logger.warning(
145+
"Overwriting existing embedding tracker record for SemanticMsg %s",
146+
semantic_msg_id,
147+
)
148+
149+
self._tasks[semantic_msg_id] = _EmbeddingTaskRecord(
150+
remaining=total_count,
151+
total=total_count,
152+
on_complete=on_complete,
153+
metadata=metadata or {},
154+
owner_loop=owner_loop,
155+
)
65156
logger.info(
66157
f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: "
67158
f"{total_count} tasks"
68159
)
69160

70-
if total_count <= 0 and on_complete:
71-
del self._tasks[semantic_msg_id]
161+
if total_count <= 0:
162+
record_to_finalize = self._tasks.pop(semantic_msg_id)
72163
logger.info(
73164
f"No embedding tasks for SemanticMsg {semantic_msg_id}, "
74-
f"triggering on_complete immediately"
165+
f"clearing tracker entry immediately"
75166
)
76167

77-
if total_count <= 0 and on_complete:
78-
try:
79-
result = on_complete()
80-
if asyncio.iscoroutine(result):
81-
await result
82-
except Exception as e:
83-
logger.error(
84-
f"Error in completion callback for {semantic_msg_id}: {e}",
85-
exc_info=True,
86-
)
168+
if record_to_finalize is not None:
169+
await self._run_on_complete(semantic_msg_id, record_to_finalize)
87170

88171
async def decrement(self, semantic_msg_id: str) -> Optional[int]:
89172
"""Decrement the remaining task count for a SemanticMsg.
@@ -98,32 +181,22 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]:
98181
Returns:
99182
The remaining count after decrement, or None if not found
100183
"""
101-
on_complete = None
184+
record_to_finalize: Optional[_EmbeddingTaskRecord] = None
102185

103-
async with self._lock:
104-
if semantic_msg_id not in self._tasks:
186+
with self._lock:
187+
record = self._tasks.get(semantic_msg_id)
188+
if record is None:
105189
return None
106190

107-
task_info = self._tasks[semantic_msg_id]
108-
task_info["remaining"] -= 1
109-
remaining = task_info["remaining"]
191+
record.remaining -= 1
192+
remaining = record.remaining
110193

111194
if remaining <= 0:
112-
on_complete = task_info.get("on_complete")
113-
114-
del self._tasks[semantic_msg_id]
195+
record_to_finalize = self._tasks.pop(semantic_msg_id)
115196
logger.info(
116-
f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}"
197+
f"All embedding tasks({record.total}) completed for SemanticMsg {semantic_msg_id}"
117198
)
118199

119-
if on_complete:
120-
try:
121-
result = on_complete()
122-
if asyncio.iscoroutine(result):
123-
await result
124-
except Exception as e:
125-
logger.error(
126-
f"Error in completion callback for {semantic_msg_id}: {e}",
127-
exc_info=True,
128-
)
200+
if record_to_finalize is not None:
201+
await self._run_on_complete(semantic_msg_id, record_to_finalize)
129202
return remaining

0 commit comments

Comments
 (0)