Skip to content

Commit c9e3590

Browse files
committed
attempted fix
1 parent 0155e77 commit c9e3590

3 files changed

Lines changed: 27 additions & 28 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
449449
finally:
450450
if output_buffer:
451451
await output_buffer.close()
452+
await rollout_processor.acleanup()
453+
rollout_processor.cleanup()
452454

453455
for res in priority_results:
454456
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
@@ -697,15 +699,19 @@ async def _collect_result(config, lst):
697699
# Lazy import (cached after first import above)
698700
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
699701

700-
if isinstance(rollout_processor, MCPGymRolloutProcessor):
701-
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
702-
for run_idx in range(num_runs):
703-
task = asyncio.create_task(execute_run(run_idx, config))
704-
await task
705-
else:
706-
# For other processors, create all tasks at once and run in parallel
707-
# Concurrency is now controlled by the shared semaphore in each rollout processor
708-
await run_tasks_with_run_progress(execute_run, num_runs, config)
702+
try:
703+
if isinstance(rollout_processor, MCPGymRolloutProcessor):
704+
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
705+
for run_idx in range(num_runs):
706+
task = asyncio.create_task(execute_run(run_idx, config))
707+
await task
708+
else:
709+
# For other processors, create all tasks at once and run in parallel
710+
# Concurrency is now controlled by the shared semaphore in each rollout processor
711+
await run_tasks_with_run_progress(execute_run, num_runs, config)
712+
finally:
713+
await rollout_processor.acleanup()
714+
rollout_processor.cleanup()
709715

710716
experiment_duration_seconds = time.perf_counter() - experiment_start_time
711717

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,12 @@ async def execute_row_with_backoff_and_log(
476476
yield result
477477

478478
finally:
479-
await rollout_processor.acleanup()
480-
rollout_processor.cleanup()
479+
# Cleanup is intentionally NOT called here. rollout_processor_with_retry
480+
# is invoked per-run, but the processor (and its session) is shared
481+
# across parallel runs. Closing per-run would kill in-flight requests
482+
# in other runs. Cleanup is called once after all runs complete in
483+
# evaluation_test.py.
484+
pass
481485

482486

483487
def sanitize_filename(text: str) -> str:

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,13 @@ def __init__(
4949
self._timeout_seconds = timeout_seconds
5050
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
5151
self._session: Optional[aiohttp.ClientSession] = None
52-
self._active_runs = 0
5352

5453
def _get_or_create_session(self) -> aiohttp.ClientSession:
5554
if self._session is None or self._session.closed:
5655
self._session = aiohttp.ClientSession()
5756
return self._session
5857

5958
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
60-
self._active_runs += 1
6159
tasks: List[asyncio.Task[EvaluationRow]] = []
6260

6361
# Start with constructor values
@@ -208,26 +206,17 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
208206
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
209207
return tasks
210208

211-
def _should_close_session(self) -> bool:
212-
self._active_runs = max(0, self._active_runs - 1)
213-
return self._active_runs == 0 and self._session is not None and not self._session.closed
214-
215209
async def acleanup(self) -> None:
216-
"""Async cleanup — only closes the session when the last run finishes.
217-
218-
rollout_processor_with_retry calls acleanup() per-run, but the session
219-
is shared across parallel runs. Closing it early would cancel in-flight
220-
requests in other runs.
221-
"""
222-
if self._should_close_session():
223-
await self._session.close() # type: ignore[union-attr]
210+
"""Async cleanup - preferred when you can await."""
211+
if self._session and not self._session.closed:
212+
await self._session.close()
224213

225214
def cleanup(self) -> None:
226-
"""Sync cleanup best-effort fallback when not in an async context."""
227-
if self._should_close_session():
215+
"""Sync cleanup - best-effort, schedules close if event loop is running."""
216+
if self._session and not self._session.closed:
228217
try:
229218
loop = asyncio.get_running_loop()
230-
loop.create_task(self._session.close()) # type: ignore[union-attr]
219+
loop.create_task(self._session.close())
231220
except RuntimeError:
232221
logger.warning(
233222
"RemoteRolloutProcessor.cleanup() called outside of async context. "

0 commit comments

Comments
 (0)