Skip to content

Commit 114ad0e

Browse files
abrichrclaude
andauthored
fix: use training-appropriate evaluate timeouts instead of reordering eval (#246)
Reverts the evaluate_dense reordering from #245 (local-first was too aggressive — skipped binary eval entirely, losing the signal when 5050 IS available). The actual fix: set evaluate_timeout=15s and evaluate_retries=1 on the WAALiveAdapter in the TRL wrapper. The evaluate_dense logic stays correct (try binary first, local fallback, take max). Training speed comes from fast failure, not from skipping evaluation paths. - Benchmarking: 180s timeout, 3 retries (thorough, one-shot) - Training: 15s timeout, 1 retry (fast feedback, thousands of evals) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0922b0a commit 114ad0e

3 files changed

Lines changed: 58 additions & 51 deletions

File tree

openadapt_evals/adapters/rl_env.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -602,18 +602,27 @@ def evaluate_dense(self) -> float:
602602
if total > 0:
603603
milestone_score = passed / total
604604

605-
# Try LOCAL evaluation FIRST (fast, ~5s) when we have
606-
# task config checks. This avoids the 9+ minute timeout
607-
# when the /evaluate endpoint (port 5050) is unresponsive.
608-
# Only fall back to binary evaluate() if local eval fails
609-
# or no local checks are defined.
610-
binary_score = 0.0
611-
server_url = getattr(
612-
getattr(self._adapter, "config", None),
613-
"server_url", "",
614-
) or ""
615-
616-
if self._task_config.checks and screenshot:
605+
# Try binary evaluation (remote /evaluate endpoint).
606+
# Speed depends on adapter's evaluate_timeout config:
607+
# - Benchmarking: 180s timeout, 3 retries (thorough)
608+
# - Training: 15s timeout, 1 retry (fast feedback)
609+
# The TRL wrapper sets training-appropriate timeouts.
610+
try:
611+
binary_score = self.evaluate()
612+
except Exception:
613+
binary_score = 0.0
614+
615+
# If binary eval returned 0.0 (endpoint down or task
616+
# failed), try local evaluation via task config checks.
617+
if (
618+
binary_score == 0.0
619+
and self._task_config.checks
620+
and screenshot
621+
):
622+
server_url = getattr(
623+
getattr(self._adapter, "config", None),
624+
"server_url", "",
625+
) or ""
617626
try:
618627
binary_score = (
619628
self._task_config.evaluate_checks_local(
@@ -622,25 +631,15 @@ def evaluate_dense(self) -> float:
622631
)
623632
if binary_score > 0:
624633
logger.info(
625-
"evaluate_dense: local checks returned %.2f "
626-
"(skipping slow /evaluate endpoint)",
627-
binary_score,
634+
"evaluate_dense: local check fallback "
635+
"returned %.2f", binary_score,
628636
)
629637
except Exception as exc:
630638
logger.debug(
631-
"evaluate_dense: local check failed: %s", exc,
639+
"evaluate_dense: local check fallback "
640+
"failed: %s", exc,
632641
)
633642

634-
# Only try the slow /evaluate endpoint if local eval
635-
# returned 0.0 AND no local checks were available.
636-
# This is the path that causes 9+ min timeouts when
637-
# port 5050 is down.
638-
if binary_score == 0.0 and not self._task_config.checks:
639-
try:
640-
binary_score = self.evaluate()
641-
except Exception:
642-
binary_score = 0.0
643-
644643
# Use the higher of milestone score and binary score
645644
score = max(milestone_score, binary_score)
646645

openadapt_evals/training/trl_wrapper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ def train(self) -> str:
152152
adapter = WAALiveAdapter(WAALiveConfig(
153153
server_url=self._config.server_url,
154154
evaluate_url=getattr(self._config, "evaluate_url", None),
155+
# Training-appropriate timeouts: fail fast, don't block the
156+
# training loop. Benchmark defaults (180s, 3 retries) are for
157+
# one-shot evaluation where thoroughness matters. Training does
158+
# thousands of evaluations where speed matters.
159+
evaluate_timeout=15.0,
160+
evaluate_retries=1,
155161
))
156162
rollout_func = make_waa_rollout_func(
157163
adapter=adapter,

tests/test_dense_rewards.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,16 @@ def test_reset_uses_task_config_for_task_loading(self):
224224
assert env._current_task.task_id == "test-001"
225225

226226

227-
class TestEvaluateDenseLocalFirst:
228-
"""Verify evaluate_dense tries local checks BEFORE the slow /evaluate endpoint.
227+
class TestEvaluateDenseEvalOrder:
228+
"""Verify evaluate_dense tries binary first, local fallback second.
229229
230-
This is critical for training performance: the /evaluate endpoint on port
231-
5050 can timeout for 9+ minutes (180s × 3 retries), while local checks
232-
take ~5 seconds. The evaluate_dense path must try local first.
230+
Training speed comes from the adapter's timeout config (15s for training
231+
vs 180s for benchmarking), NOT from skipping the binary eval path.
232+
Both evaluation methods are tried, and the max score is used.
233233
"""
234234

235-
def test_local_eval_before_binary_when_checks_defined(self):
236-
"""When task has checks, local eval runs first and binary is skipped."""
235+
def test_binary_eval_called_first(self):
236+
"""Binary evaluate() is always called when milestones exist."""
237237
adapter = _make_adapter()
238238
check = TaskCheck(check="command", run="echo 1", expect="1", match="exact")
239239
task_config = _make_task_config(
@@ -244,36 +244,37 @@ def test_local_eval_before_binary_when_checks_defined(self):
244244
env = RLEnvironment(adapter, task_config=task_config)
245245
env.reset(config=ResetConfig(task_id="test-001"))
246246

247-
with patch.object(task_config, "evaluate_checks_local", return_value=1.0) as mock_local:
248-
score = env.evaluate_dense()
247+
with patch.object(task_config, "evaluate_checks_local", return_value=1.0):
248+
env.evaluate_dense()
249249

250-
mock_local.assert_called_once()
251-
adapter.evaluate.assert_not_called()
252-
assert score >= 1.0
250+
# Binary eval was called (returns 0.0 from mock default)
251+
adapter.evaluate.assert_called_once()
253252

254-
def test_binary_eval_used_when_no_checks(self):
255-
"""When task has no checks, falls through to binary evaluate."""
253+
def test_local_fallback_when_binary_returns_zero(self):
254+
"""Local checks run as fallback when binary returns 0.0."""
256255
adapter = _make_adapter()
257-
adapter.evaluate.return_value = BenchmarkResult(
258-
task_id="test-001", success=True, score=0.75,
259-
)
260256
check = TaskCheck(check="command", run="echo 1", expect="1", match="exact")
261257
task_config = _make_task_config(
262258
milestones=[Milestone(name="Step done", check=check)],
263259
)
264-
# No checks — must fall through to binary
260+
task_config.checks = [check]
265261

266262
env = RLEnvironment(adapter, task_config=task_config)
267263
env.reset(config=ResetConfig(task_id="test-001"))
268264

269-
score = env.evaluate_dense()
265+
with patch.object(task_config, "evaluate_checks_local", return_value=1.0) as mock_local:
266+
score = env.evaluate_dense()
270267

271-
adapter.evaluate.assert_called_once()
268+
mock_local.assert_called_once()
269+
assert score >= 1.0
272270

273-
def test_local_eval_failure_does_not_call_binary(self):
274-
"""When local eval returns 0.0, binary is still skipped if checks exist."""
271+
def test_local_not_called_when_binary_succeeds(self):
272+
"""Local checks are skipped when binary eval returns > 0."""
275273
adapter = _make_adapter()
276-
check = TaskCheck(check="command", run="echo 0", expect="1", match="exact")
274+
adapter.evaluate.return_value = BenchmarkResult(
275+
task_id="test-001", success=True, score=0.75,
276+
)
277+
check = TaskCheck(check="command", run="echo 1", expect="1", match="exact")
277278
task_config = _make_task_config(
278279
milestones=[Milestone(name="Step done", check=check)],
279280
)
@@ -282,7 +283,8 @@ def test_local_eval_failure_does_not_call_binary(self):
282283
env = RLEnvironment(adapter, task_config=task_config)
283284
env.reset(config=ResetConfig(task_id="test-001"))
284285

285-
with patch.object(task_config, "evaluate_checks_local", return_value=0.0):
286+
with patch.object(task_config, "evaluate_checks_local") as mock_local:
286287
score = env.evaluate_dense()
287288

288-
adapter.evaluate.assert_not_called()
289+
mock_local.assert_not_called()
290+
assert score >= 0.75

0 commit comments

Comments
 (0)