Skip to content

Commit 3d2037a

Browse files
authored
fix(eval) Fix import async (#1203)
* import async
* import async
1 parent 89616a6 commit 3d2037a

1 file changed

Lines changed: 39 additions & 18 deletions

File tree

benchmark/locomo/vikingbot/import_to_ov.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,19 @@ def mark_ingested(
224224
# ---------------------------------------------------------------------------
225225
# OpenViking import
226226
# ---------------------------------------------------------------------------
227-
def _parse_token_usage(commit_result: Dict[str, Any]) -> Dict[str, int]:
228-
"""解析Token使用数据(从commit返回的telemetry中提取)"""
229-
telemetry = commit_result.get("telemetry", {}).get("summary", {})
230-
tokens = telemetry.get("tokens", {})
227+
def _parse_token_usage(task_result: Dict[str, Any]) -> Dict[str, int]:
228+
"""解析Token使用数据(从get_task返回的result中提取)"""
229+
result_data = task_result.get("result", {})
230+
token_usage = result_data.get("token_usage", {})
231+
llm_tokens = token_usage.get("llm", {})
232+
embedding_tokens = token_usage.get("embedding", {})
233+
total_tokens = token_usage.get("total", {})
231234
return {
232-
"embedding": tokens.get("embedding", {}).get("total", 0),
233-
"vlm": tokens.get("llm", {}).get("total", 0),
234-
"llm_input": tokens.get("llm", {}).get("input", 0),
235-
"llm_output": tokens.get("llm", {}).get("output", 0),
236-
"total": tokens.get("total", 0)
235+
"embedding": embedding_tokens.get("total_tokens", 0),
236+
"vlm": llm_tokens.get("total_tokens", 0),
237+
"llm_input": llm_tokens.get("prompt_tokens", 0),
238+
"llm_output": llm_tokens.get("completion_tokens", 0),
239+
"total": total_tokens.get("total_tokens", 0)
237240
}
238241

239242

@@ -287,13 +290,32 @@ async def viking_ingest(
287290
)
288291

289292
# Commit
290-
result = await client.commit_session(session_id, telemetry=True)
291-
292-
if result.get("status") != "committed":
293-
raise RuntimeError(f"Commit failed: {result}")
294-
295-
# 直接从commit结果中提取token使用情况
296-
token_usage = _parse_token_usage(result)
293+
commit_result = await client.commit_session(session_id, telemetry=True)
294+
295+
if commit_result.get("status") != "accepted":
296+
raise RuntimeError(f"Commit failed: {commit_result}")
297+
298+
# 获取异步任务ID并轮询任务完成状态
299+
task_id = commit_result.get("task_id")
300+
if not task_id:
301+
raise RuntimeError(f"No task_id in commit result: {commit_result}")
302+
303+
# 轮询任务状态直到完成
304+
max_attempts = 1200 # 最多等待20分钟
305+
for attempt in range(max_attempts):
306+
task_result = await client.get_task(task_id)
307+
task_status = task_result.get("status")
308+
if task_status == "completed":
309+
break
310+
elif task_status in ("failed", "cancelled"):
311+
raise RuntimeError(f"Task {task_id} {task_status}: {task_result.get('error')}")
312+
# 等待1秒后重试
313+
await asyncio.sleep(1)
314+
else:
315+
raise RuntimeError(f"Task {task_id} timed out after {max_attempts} attempts")
316+
317+
# 从任务结果中提取token使用情况
318+
token_usage = _parse_token_usage(task_result)
297319

298320
return token_usage
299321

@@ -306,7 +328,6 @@ def sync_viking_ingest(messages: List[Dict[str, Any]], openviking_url: str, sess
306328
semaphore = asyncio.Semaphore(1) # 同步调用时使用信号量为1
307329
return asyncio.run(viking_ingest(messages, openviking_url, semaphore, session_time))
308330

309-
310331
# ---------------------------------------------------------------------------
311332
# Main import logic
312333
# ---------------------------------------------------------------------------
@@ -602,4 +623,4 @@ def main():
602623

603624

604625
if __name__ == "__main__":
605-
main()
626+
main()

0 commit comments

Comments
 (0)