33"""Embedding Task Tracker for tracking embedding task completion status."""
44
55import asyncio
6+ import inspect
7+ import threading
8+ from dataclasses import dataclass
69from typing import Any , Callable , Dict , Optional
710
811from openviking_cli .utils .logger import get_logger
912
1013logger = get_logger (__name__ )
1114
1215
16+ @dataclass
17+ class _EmbeddingTaskRecord :
18+ """Coordinator state for a single semantic message."""
19+
20+ remaining : int
21+ total : int
22+ on_complete : Optional [Callable [[], Any ]]
23+ metadata : Dict [str , Any ]
24+ owner_loop : Optional [asyncio .AbstractEventLoop ]
25+
26+
1327class EmbeddingTaskTracker :
1428 """Track embedding task completion status for each SemanticMsg.
1529
16- This tracker maintains a global registry of embedding tasks associated
17- with each SemanticMsg. When all embedding tasks for a SemanticMsg are
18- completed, it triggers the registered callback and removes the entry.
30+ This tracker maintains a process-global registry of embedding tasks associated
31+ with each SemanticMsg. Because semantic and embedding queues run on separate
32+ worker threads with distinct event loops, its internal state must be guarded
33+ by thread-safe primitives rather than loop-bound asyncio locks.
34+
35+ When all embedding tasks for a SemanticMsg are completed, it triggers the
36+ registered callback and removes the entry.
1937 """
2038
2139 _instance : Optional ["EmbeddingTaskTracker" ] = None
@@ -29,10 +47,72 @@ def __new__(cls) -> "EmbeddingTaskTracker":
2947 def __init__ (self ):
3048 if self ._initialized :
3149 return
32- self ._lock : asyncio . Lock = asyncio .Lock ()
33- self ._tasks : Dict [str , Dict [ str , Any ] ] = {}
50+ self ._lock = threading .Lock ()
51+ self ._tasks : Dict [str , _EmbeddingTaskRecord ] = {}
3452 self ._initialized = True
3553
54+ @staticmethod
55+ async def _await_callback_result (result : Any ) -> None :
56+ """Await callback results when they are async."""
57+ if inspect .isawaitable (result ):
58+ await result
59+
60+ async def _execute_callback (self , on_complete : Callable [[], Any ]) -> None :
61+ """Invoke a completion callback and await async results."""
62+ await self ._await_callback_result (on_complete ())
63+
64+ async def _run_on_complete (
65+ self ,
66+ semantic_msg_id : str ,
67+ record : _EmbeddingTaskRecord ,
68+ ) -> None :
69+ """Execute the completion callback on the loop that registered it."""
70+ on_complete = record .on_complete
71+ owner_loop = record .owner_loop
72+ if on_complete is None :
73+ return
74+
75+ try :
76+ current_loop = asyncio .get_running_loop ()
77+ except RuntimeError :
78+ current_loop = None
79+
80+ owner_loop_running = bool (owner_loop and owner_loop .is_running ())
81+ owner_loop_available = bool (
82+ owner_loop and not owner_loop .is_closed () and owner_loop_running
83+ )
84+
85+ try :
86+ if owner_loop and owner_loop is not current_loop :
87+ if not owner_loop_available :
88+ logger .warning (
89+ "Owner loop unavailable before completion callback for %s; "
90+ "running callback in current loop" ,
91+ semantic_msg_id ,
92+ )
93+ else :
94+ try :
95+ fut = asyncio .run_coroutine_threadsafe (
96+ self ._execute_callback (on_complete ),
97+ owner_loop ,
98+ )
99+ except RuntimeError :
100+ logger .warning (
101+ "Owner loop stopped before completion callback for %s; "
102+ "running callback in current loop" ,
103+ semantic_msg_id ,
104+ )
105+ else :
106+ await asyncio .wrap_future (fut )
107+ return
108+
109+ await self ._execute_callback (on_complete )
110+ except Exception as e :
111+ logger .error (
112+ f"Error in completion callback for { semantic_msg_id } : { e } " ,
113+ exc_info = True ,
114+ )
115+
36116 @classmethod
37117 def get_instance (cls ) -> "EmbeddingTaskTracker" :
38118 """Get the singleton instance of EmbeddingTaskTracker."""
@@ -55,35 +135,38 @@ async def register(
55135 on_complete: Optional callback when all tasks complete
56136 metadata: Optional metadata to store with the task
57137 """
58- async with self ._lock :
59- self ._tasks [semantic_msg_id ] = {
60- "remaining" : total_count ,
61- "total" : total_count ,
62- "on_complete" : on_complete ,
63- "metadata" : metadata or {},
64- }
138+ owner_loop = asyncio .get_running_loop ()
139+ record_to_finalize : Optional [_EmbeddingTaskRecord ] = None
140+
141+ with self ._lock :
142+ existing = self ._tasks .get (semantic_msg_id )
143+ if existing is not None :
144+ logger .warning (
145+ "Overwriting existing embedding tracker record for SemanticMsg %s" ,
146+ semantic_msg_id ,
147+ )
148+
149+ self ._tasks [semantic_msg_id ] = _EmbeddingTaskRecord (
150+ remaining = total_count ,
151+ total = total_count ,
152+ on_complete = on_complete ,
153+ metadata = metadata or {},
154+ owner_loop = owner_loop ,
155+ )
65156 logger .info (
66157 f"Registered embedding tracker for SemanticMsg { semantic_msg_id } : "
67158 f"{ total_count } tasks"
68159 )
69160
70- if total_count <= 0 and on_complete :
71- del self ._tasks [ semantic_msg_id ]
161+ if total_count <= 0 :
162+ record_to_finalize = self ._tasks . pop ( semantic_msg_id )
72163 logger .info (
73164 f"No embedding tasks for SemanticMsg { semantic_msg_id } , "
74- f"triggering on_complete immediately"
165+ f"clearing tracker entry immediately"
75166 )
76167
77- if total_count <= 0 and on_complete :
78- try :
79- result = on_complete ()
80- if asyncio .iscoroutine (result ):
81- await result
82- except Exception as e :
83- logger .error (
84- f"Error in completion callback for { semantic_msg_id } : { e } " ,
85- exc_info = True ,
86- )
168+ if record_to_finalize is not None :
169+ await self ._run_on_complete (semantic_msg_id , record_to_finalize )
87170
88171 async def decrement (self , semantic_msg_id : str ) -> Optional [int ]:
89172 """Decrement the remaining task count for a SemanticMsg.
@@ -98,32 +181,22 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]:
98181 Returns:
99182 The remaining count after decrement, or None if not found
100183 """
101- on_complete = None
184+ record_to_finalize : Optional [ _EmbeddingTaskRecord ] = None
102185
103- async with self ._lock :
104- if semantic_msg_id not in self ._tasks :
186+ with self ._lock :
187+ record = self ._tasks .get (semantic_msg_id )
188+ if record is None :
105189 return None
106190
107- task_info = self ._tasks [semantic_msg_id ]
108- task_info ["remaining" ] -= 1
109- remaining = task_info ["remaining" ]
191+ record .remaining -= 1
192+ remaining = record .remaining
110193
111194 if remaining <= 0 :
112- on_complete = task_info .get ("on_complete" )
113-
114- del self ._tasks [semantic_msg_id ]
195+ record_to_finalize = self ._tasks .pop (semantic_msg_id )
115196 logger .info (
116- f"All embedding tasks({ task_info [ ' total' ] } ) completed for SemanticMsg { semantic_msg_id } "
197+ f"All embedding tasks({ record . total } ) completed for SemanticMsg { semantic_msg_id } "
117198 )
118199
119- if on_complete :
120- try :
121- result = on_complete ()
122- if asyncio .iscoroutine (result ):
123- await result
124- except Exception as e :
125- logger .error (
126- f"Error in completion callback for { semantic_msg_id } : { e } " ,
127- exc_info = True ,
128- )
200+ if record_to_finalize is not None :
201+ await self ._run_on_complete (semantic_msg_id , record_to_finalize )
129202 return remaining
0 commit comments