@@ -224,16 +224,19 @@ def mark_ingested(
224224# ---------------------------------------------------------------------------
225225# OpenViking import
226226# ---------------------------------------------------------------------------
227- def _parse_token_usage (commit_result : Dict [str , Any ]) -> Dict [str , int ]:
228- """解析Token使用数据(从commit返回的telemetry中提取)"""
229- telemetry = commit_result .get ("telemetry" , {}).get ("summary" , {})
230- tokens = telemetry .get ("tokens" , {})
227+ def _parse_token_usage (task_result : Dict [str , Any ]) -> Dict [str , int ]:
228+ """解析Token使用数据(从get_task返回的result中提取)"""
229+ result_data = task_result .get ("result" , {})
230+ token_usage = result_data .get ("token_usage" , {})
231+ llm_tokens = token_usage .get ("llm" , {})
232+ embedding_tokens = token_usage .get ("embedding" , {})
233+ total_tokens = token_usage .get ("total" , {})
231234 return {
232- "embedding" : tokens .get ("embedding" , {}). get ( "total " , 0 ),
233- "vlm" : tokens .get ("llm" , {}). get ( "total " , 0 ),
234- "llm_input" : tokens .get ("llm" , {}). get ( "input " , 0 ),
235- "llm_output" : tokens .get ("llm" , {}). get ( "output " , 0 ),
236- "total" : tokens .get ("total " , 0 )
235+ "embedding" : embedding_tokens .get ("total_tokens " , 0 ),
236+ "vlm" : llm_tokens .get ("total_tokens " , 0 ),
237+ "llm_input" : llm_tokens .get ("prompt_tokens " , 0 ),
238+ "llm_output" : llm_tokens .get ("completion_tokens " , 0 ),
239+ "total" : total_tokens .get ("total_tokens " , 0 )
237240 }
238241
239242
@@ -287,13 +290,32 @@ async def viking_ingest(
287290 )
288291
289292 # Commit
290- result = await client .commit_session (session_id , telemetry = True )
291-
292- if result .get ("status" ) != "committed" :
293- raise RuntimeError (f"Commit failed: { result } " )
294-
295- # 直接从commit结果中提取token使用情况
296- token_usage = _parse_token_usage (result )
293+ commit_result = await client .commit_session (session_id , telemetry = True )
294+
295+ if commit_result .get ("status" ) != "accepted" :
296+ raise RuntimeError (f"Commit failed: { commit_result } " )
297+
298+ # 获取异步任务ID并轮询任务完成状态
299+ task_id = commit_result .get ("task_id" )
300+ if not task_id :
301+ raise RuntimeError (f"No task_id in commit result: { commit_result } " )
302+
303+ # 轮询任务状态直到完成
304+ max_attempts = 1200 # 最多等待20分钟
305+ for attempt in range (max_attempts ):
306+ task_result = await client .get_task (task_id )
307+ task_status = task_result .get ("status" )
308+ if task_status == "completed" :
309+ break
310+ elif task_status in ("failed" , "cancelled" ):
311+ raise RuntimeError (f"Task { task_id } { task_status } : { task_result .get ('error' )} " )
312+ # 等待1秒后重试
313+ await asyncio .sleep (1 )
314+ else :
315+ raise RuntimeError (f"Task { task_id } timed out after { max_attempts } attempts" )
316+
317+ # 从任务结果中提取token使用情况
318+ token_usage = _parse_token_usage (task_result )
297319
298320 return token_usage
299321
@@ -306,7 +328,6 @@ def sync_viking_ingest(messages: List[Dict[str, Any]], openviking_url: str, sess
306328 semaphore = asyncio .Semaphore (1 ) # 同步调用时使用信号量为1
307329 return asyncio .run (viking_ingest (messages , openviking_url , semaphore , session_time ))
308330
309-
310331# ---------------------------------------------------------------------------
311332# Main import logic
312333# ---------------------------------------------------------------------------
@@ -602,4 +623,4 @@ def main():
602623
603624
604625if __name__ == "__main__" :
605- main ()
626+ main ()
0 commit comments