Skip to content

Commit 524fb59

Browse files
committed
fix: use async connection pooling for tracing log polls
Replace per-request requests.get() calls in search_logs with a native async aiohttp implementation that reuses the RemoteRolloutProcessor's shared ClientSession. This eliminates the SSL EOF errors seen under high concurrency by reusing TCP+TLS connections instead of opening a fresh handshake on every poll.

Additional fixes:
- Remove default connection pool cap (TCPConnector limit=0) so the semaphore is the only concurrency bound
- Increase /init timeout from 300s to 600s
- Move common_utils import to module level

Made-with: Cursor
1 parent f77b26f commit 524fb59

2 files changed

Lines changed: 56 additions & 10 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
"""
66

77
from __future__ import annotations
8+
import asyncio
89
import logging
10+
import aiohttp
911
import requests
1012
from datetime import datetime
1113
from typing import Any, Dict, List, Optional, Protocol
@@ -14,6 +16,7 @@
1416
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
1517
from .base import BaseAdapter
1618
from .utils import extract_messages_from_data
19+
from ..common_utils import get_user_agent
1720

1821
logger = logging.getLogger(__name__)
1922

@@ -280,8 +283,6 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
280283
if not tags:
281284
raise ValueError("At least one tag is required to fetch logs")
282285

283-
from ..common_utils import get_user_agent
284-
285286
headers = {
286287
"Authorization": f"Bearer {self._get_api_key()}",
287288
"User-Agent": get_user_agent(),
@@ -327,6 +328,52 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
327328
)
328329
return results
329330

331+
async def async_search_logs(
332+
self, session: aiohttp.ClientSession, tags: List[str], limit: int = 100, hours_back: int = 24
333+
) -> List[Dict[str, Any]]:
334+
"""Async version of search_logs, reuses a caller-provided aiohttp session."""
335+
if not tags:
336+
raise ValueError("At least one tag is required to fetch logs")
337+
338+
params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"}
339+
headers = {"Authorization": f"Bearer {self._get_api_key()}", "User-Agent": get_user_agent()}
340+
timeout = aiohttp.ClientTimeout(total=self.timeout)
341+
342+
urls_to_try = [f"{self.base_url}/logs", f"{self.base_url}/v1/logs"]
343+
data: Dict[str, Any] = {}
344+
last_error: Optional[str] = None
345+
for url in urls_to_try:
346+
try:
347+
async with session.get(url, params=params, headers=headers, timeout=timeout) as resp:
348+
if resp.status == 404:
349+
last_error = f"404 for {url}"
350+
continue
351+
resp.raise_for_status()
352+
data = (await resp.json()) or {}
353+
break
354+
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
355+
last_error = str(e)
356+
continue
357+
else:
358+
if last_error:
359+
logger.error("Failed to fetch logs from Fireworks (tried %s): %s", urls_to_try, last_error)
360+
return []
361+
362+
entries: List[Dict[str, Any]] = data.get("entries", []) or []
363+
results: List[Dict[str, Any]] = []
364+
for e in entries:
365+
results.append(
366+
{
367+
"timestamp": e.get("timestamp"),
368+
"message": e.get("message"),
369+
"severity": e.get("severity", "INFO"),
370+
"tags": e.get("tags", []),
371+
"status": e.get("status"),
372+
"extras": e.get("extras"),
373+
}
374+
)
375+
return results
376+
330377
def get_evaluation_rows(
331378
self,
332379
tags: List[str],
@@ -411,8 +458,6 @@ def get_evaluation_rows(
411458
else:
412459
url = f"{self.base_url}/v1/traces/pointwise"
413460

414-
from ..common_utils import get_user_agent
415-
416461
headers = {
417462
"Authorization": f"Bearer {self._get_api_key()}",
418463
"User-Agent": get_user_agent(),

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(
5252

5353
def _get_or_create_session(self) -> aiohttp.ClientSession:
5454
if self._session is None or self._session.closed:
55-
self._session = aiohttp.ClientSession()
55+
connector = aiohttp.TCPConnector(limit=0) # no total connection limit
56+
self._session = aiohttp.ClientSession(connector=connector)
5657
return self._session
5758

5859
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
@@ -99,7 +100,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
99100
# Fire-and-poll
100101
init_url = f"{remote_base_url}/init"
101102

102-
timeout_init = aiohttp.ClientTimeout(total=300)
103+
timeout_init = aiohttp.ClientTimeout(total=600)
103104

104105
try:
105106
session = self._get_or_create_session()
@@ -114,15 +115,15 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
114115
await resp.read() # Drain the response body and release the connection back to the pool
115116
except asyncio.TimeoutError:
116117
raise TimeoutError(
117-
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
118+
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 600 seconds."
118119
)
119120

120121
deadline = time.time() + timeout_seconds
121122

122123
while time.time() < deadline:
123-
# Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop)
124-
completed_logs = await asyncio.to_thread(
125-
self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
124+
session = self._get_or_create_session()
125+
completed_logs = await self._tracing_adapter.async_search_logs(
126+
session, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
126127
)
127128
# Filter for logs that actually have status information
128129
status_logs = []

0 commit comments

Comments
 (0)