Skip to content

Commit e71ed9f

Browse files
abrichr and claude authored
fix: add truncation warning to TRL generate paths (#242)
Add a truncation check after both generation paths (Outlines constrained and HF unconstrained) in generate_fn. When the output length reaches max_new_tokens - 1, a warning is logged suggesting to increase max_new_tokens or enable constrained_decoding. This helps diagnose cases where the model generates excessively long reasoning that gets cut off before producing a parseable action. Also replaced the tautological truncation tests in test_trl_robustness.py (which reimplemented the check logic inline) with tests that exercise the actual generate_fn code path by calling it through the rollout function with mocked torch and model.generate. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6a38956 commit e71ed9f

2 files changed

Lines changed: 155 additions and 64 deletions

File tree

openadapt_evals/training/trl_rollout.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,16 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
533533
# return empty logprobs. TRL recomputes logprobs from
534534
# the model during the training step anyway.
535535
logprobs: list[float] = []
536+
537+
# Truncation warning — detect when output was cut off
538+
if len(completion_ids) >= max_new_tokens - 1:
539+
logger.warning(
540+
"Generation hit max_new_tokens=%d. Output may be "
541+
"truncated. If actions are unparseable, increase "
542+
"max_new_tokens or enable constrained_decoding.",
543+
max_new_tokens,
544+
)
545+
536546
return decoded, completion_ids, logprobs
537547

538548
# --- Standard HF generate path (unconstrained) ---
@@ -567,6 +577,15 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
567577
# Decode text
568578
text = processor.decode(completion_ids, skip_special_tokens=True)
569579

580+
# Truncation warning — detect when output was cut off
581+
if len(completion_ids) >= max_new_tokens - 1:
582+
logger.warning(
583+
"Generation hit max_new_tokens=%d. Output may be "
584+
"truncated. If actions are unparseable, increase "
585+
"max_new_tokens or enable constrained_decoding.",
586+
max_new_tokens,
587+
)
588+
570589
return text, completion_ids, logprobs
571590

572591
for prompt in prompts:

tests/test_trl_robustness.py

Lines changed: 136 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -421,96 +421,168 @@ def fake_generate(screenshot_bytes, instruction):
421421

422422

423423
class TestTruncationWarning:
424-
"""Tests for truncation warning in generate_fn."""
424+
"""Tests for truncation warning in generate_fn.
425+
426+
These tests exercise the ACTUAL truncation check in generate_fn by
427+
calling it through _run_episode. The mock_run calls the real gfn()
428+
(the closure created by make_waa_rollout_func) with mocked torch and
429+
PIL so the truncation check in trl_rollout.py is exercised — not
430+
reimplemented in the test.
431+
"""
432+
433+
def test_truncation_warning_logged_hf_path(self, caplog):
434+
"""HF generate path: output hitting max_new_tokens triggers warning.
435+
436+
We intercept _run_episode, call gfn (the real generate_fn),
437+
and mock model.generate to return max_new_tokens tokens so the
438+
truncation check fires.
439+
"""
440+
import io as _io
441+
from PIL import Image
442+
from openadapt_evals.training import trl_rollout
425443

426-
def test_truncation_warning_logged(self, caplog):
427-
"""Output hitting max_new_tokens without 'done' triggers warning."""
428444
adapter = _make_mock_adapter()
445+
max_new_tokens = 10
429446
func = make_waa_rollout_func(
430447
adapter,
431-
max_steps=3,
432-
max_new_tokens=10,
448+
max_steps=1,
449+
max_new_tokens=max_new_tokens,
433450
screenshot_retries=1,
434451
screenshot_retry_delay=0,
435452
)
436453
trainer = _make_mock_trainer(num_generations=1)
437454

438-
from openadapt_evals.training import trl_rollout
455+
# Create a real small PNG so PIL.Image.open succeeds inside generate_fn
456+
img = Image.new("RGB", (10, 10), color="red")
457+
buf = _io.BytesIO()
458+
img.save(buf, format="PNG")
459+
valid_png = buf.getvalue()
460+
461+
# Make adapter return a real PNG for the health check
462+
adapter.observe.return_value = BenchmarkObservation(
463+
screenshot=valid_png, raw_observation={},
464+
)
465+
466+
# Build a mock torch module with the pieces generate_fn needs
467+
mock_torch = MagicMock()
468+
mock_torch.no_grad.return_value.__enter__ = MagicMock(return_value=None)
469+
mock_torch.no_grad.return_value.__exit__ = MagicMock(return_value=False)
470+
471+
# model.generate returns prompt(5) + completion(10) tokens
472+
prompt_len = 5
473+
fake_completion = list(range(max_new_tokens)) # 10 tokens >= 10-1
474+
fake_seq = MagicMock()
475+
fake_seq.__getitem__ = lambda self, idx: MagicMock(
476+
tolist=lambda: fake_completion,
477+
)
478+
mock_gen_output = MagicMock()
479+
mock_gen_output.sequences = [fake_seq]
480+
mock_gen_output.scores = []
481+
trainer.model.generate.return_value = mock_gen_output
482+
483+
# Mock processor to return inputs with known prompt length
484+
mock_inputs = MagicMock()
485+
mock_input_ids = MagicMock()
486+
mock_input_ids.shape = [1, prompt_len]
487+
mock_inputs.__getitem__ = lambda self, key: (
488+
mock_input_ids if key == "input_ids" else MagicMock()
489+
)
490+
mock_inputs.to.return_value = mock_inputs
491+
trainer.processing_class.return_value = mock_inputs
492+
trainer.processing_class.apply_chat_template.return_value = "fake prompt"
493+
trainer.processing_class.decode.return_value = '{"type": "done"}'
439494

440495
def mock_run(env, gfn, instr, tid, ms, stuck_threshold=3):
496+
"""Call the REAL generate_fn with mocked torch."""
441497
from openadapt_evals.adapters.rl_env import ResetConfig
442498
env.reset(config=ResetConfig(task_id=tid))
443499

444-
# Mock the generate path to simulate truncation
445-
from PIL import Image
446-
import io as _io
447-
448-
# Create a real small PNG for PIL to open
449-
img = Image.new("RGB", (10, 10), color="red")
450-
buf = _io.BytesIO()
451-
img.save(buf, format="PNG")
452-
valid_png = buf.getvalue()
453-
454-
# Patch the model to return max_new_tokens - 1 tokens
455-
# with nonsensical text that doesn't contain "done"
456-
mock_outputs = MagicMock()
457-
mock_seqs = MagicMock()
458-
mock_seqs.__getitem__ = lambda self, idx: MagicMock(
459-
tolist=lambda: list(range(9)), # 9 tokens = max_new_tokens - 1
460-
)
461-
mock_outputs.sequences = [mock_seqs[0]]
462-
mock_outputs.scores = []
463-
464-
mock_inputs = MagicMock()
465-
mock_inputs.__getitem__ = lambda self, key: MagicMock(shape=[1, 5])
466-
467-
# Simulate truncation: call generate_fn, intercept at model level
468-
# For simplicity, we'll directly test the truncation check logic
469-
# by calling parse_action_json on truncated output
470-
text_with_no_done = "I was thinking about clicking the butt"
471-
completion_ids = list(range(9)) # 9 >= 10-1 triggers check
472-
473-
if len(completion_ids) >= 10 - 1: # max_new_tokens - 1
474-
action = parse_action_json(text_with_no_done)
475-
if action.type == "done" and "done" not in text_with_no_done.lower():
476-
import logging as _logging
477-
_logging.getLogger("openadapt_evals.training.trl_rollout").warning(
478-
"Output truncated at %d tokens without parseable "
479-
"action. Consider increasing max_new_tokens "
480-
"(current: %d) or checking VRAM.",
481-
len(completion_ids),
482-
10,
483-
)
500+
with patch.dict("sys.modules", {"torch": mock_torch}):
501+
text, ids, lps = gfn(valid_png, "test instruction")
484502

485-
return [1], [2], [-0.1], 0.0
503+
return [1], ids, lps, 0.0
486504

487505
with caplog.at_level(logging.WARNING, logger="openadapt_evals.training.trl_rollout"):
488506
with patch.object(trl_rollout, "_run_episode", side_effect=mock_run):
489507
func(["Test task"], trainer)
490508

491-
assert any("truncated" in r.message.lower() for r in caplog.records), (
492-
f"Expected truncation warning in logs, got: {[r.message for r in caplog.records]}"
509+
assert any("generation hit max_new_tokens" in r.message.lower() for r in caplog.records), (
510+
f"Expected truncation warning from generate_fn, got: "
511+
f"{[r.message for r in caplog.records]}"
493512
)
494513

495-
def test_truncation_no_warning_for_done(self, caplog):
496-
"""Output says 'done' and hits limit -- no truncation warning."""
497-
text_with_done = '{"type": "done"} I am done now'
498-
completion_ids = list(range(9))
499-
max_new_tokens = 10
514+
def test_no_truncation_warning_when_short(self, caplog):
515+
"""HF generate path: short output does NOT trigger warning."""
516+
import io as _io
517+
from PIL import Image
518+
from openadapt_evals.training import trl_rollout
519+
520+
adapter = _make_mock_adapter()
521+
max_new_tokens = 256
522+
func = make_waa_rollout_func(
523+
adapter,
524+
max_steps=1,
525+
max_new_tokens=max_new_tokens,
526+
screenshot_retries=1,
527+
screenshot_retry_delay=0,
528+
)
529+
trainer = _make_mock_trainer(num_generations=1)
530+
531+
img = Image.new("RGB", (10, 10), color="red")
532+
buf = _io.BytesIO()
533+
img.save(buf, format="PNG")
534+
valid_png = buf.getvalue()
535+
536+
adapter.observe.return_value = BenchmarkObservation(
537+
screenshot=valid_png, raw_observation={},
538+
)
500539

501-
# Simulate the truncation check
502-
if len(completion_ids) >= max_new_tokens - 1:
503-
action = parse_action_json(text_with_done)
504-
# "done" IS in text.lower(), so this should NOT fire
505-
if action.type == "done" and "done" not in text_with_done.lower():
506-
logging.getLogger("openadapt_evals.training.trl_rollout").warning(
507-
"Output truncated"
508-
)
540+
mock_torch = MagicMock()
541+
mock_torch.no_grad.return_value.__enter__ = MagicMock(return_value=None)
542+
mock_torch.no_grad.return_value.__exit__ = MagicMock(return_value=False)
543+
544+
# Short completion — only 3 tokens, well below max_new_tokens
545+
prompt_len = 5
546+
fake_completion = list(range(3))
547+
fake_seq = MagicMock()
548+
fake_seq.__getitem__ = lambda self, idx: MagicMock(
549+
tolist=lambda: fake_completion,
550+
)
551+
mock_gen_output = MagicMock()
552+
mock_gen_output.sequences = [fake_seq]
553+
mock_gen_output.scores = []
554+
trainer.model.generate.return_value = mock_gen_output
555+
556+
mock_inputs = MagicMock()
557+
mock_input_ids = MagicMock()
558+
mock_input_ids.shape = [1, prompt_len]
559+
mock_inputs.__getitem__ = lambda self, key: (
560+
mock_input_ids if key == "input_ids" else MagicMock()
561+
)
562+
mock_inputs.to.return_value = mock_inputs
563+
trainer.processing_class.return_value = mock_inputs
564+
trainer.processing_class.apply_chat_template.return_value = "fake prompt"
565+
trainer.processing_class.decode.return_value = '{"type": "done"}'
566+
567+
def mock_run(env, gfn, instr, tid, ms, stuck_threshold=3):
568+
from openadapt_evals.adapters.rl_env import ResetConfig
569+
env.reset(config=ResetConfig(task_id=tid))
570+
571+
with patch.dict("sys.modules", {"torch": mock_torch}):
572+
text, ids, lps = gfn(valid_png, "test instruction")
573+
574+
return [1], ids, lps, 0.0
575+
576+
with caplog.at_level(logging.WARNING, logger="openadapt_evals.training.trl_rollout"):
577+
with patch.object(trl_rollout, "_run_episode", side_effect=mock_run):
578+
func(["Test task"], trainer)
509579

510-
# No warning should have been logged
511580
assert not any(
512-
"truncated" in r.message.lower()
581+
"generation hit max_new_tokens" in r.message.lower()
513582
for r in caplog.records
583+
), (
584+
f"Should NOT see truncation warning for short output, got: "
585+
f"{[r.message for r in caplog.records]}"
514586
)
515587

516588

0 commit comments

Comments (0)