@@ -421,96 +421,168 @@ def fake_generate(screenshot_bytes, instruction):
421421
422422
class TestTruncationWarning:
    """Tests for the truncation warning emitted by generate_fn.

    These tests exercise the ACTUAL truncation check in generate_fn by
    calling it through _run_episode. The patched _run_episode calls the real
    gfn() (the closure created by make_waa_rollout_func) with a mocked torch
    module, a mocked model/processor and a real in-memory PNG, so the
    truncation check in trl_rollout.py is exercised -- not reimplemented in
    the test.
    """

    # Logger the warning is expected on, and the lower-cased fragment of the
    # warning text that both tests search for.
    _LOGGER_NAME = "openadapt_evals.training.trl_rollout"
    _WARNING_FRAGMENT = "generation hit max_new_tokens"

    @staticmethod
    def _make_png_bytes():
        """Return the bytes of a real 10x10 red PNG.

        A genuine image is used so that PIL.Image.open succeeds inside
        generate_fn.
        """
        import io as _io
        from PIL import Image

        buf = _io.BytesIO()
        Image.new("RGB", (10, 10), color="red").save(buf, format="PNG")
        return buf.getvalue()

    def _run_generate_fn(self, caplog, *, max_new_tokens, completion_len):
        """Run one rollout through the real generate_fn and return log messages.

        Args:
            caplog: The pytest ``caplog`` fixture of the calling test.
            max_new_tokens: Value forwarded to make_waa_rollout_func.
            completion_len: Number of fake completion tokens that the mocked
                ``model.generate`` output yields.

        Returns:
            The formatted messages of every record captured by ``caplog``.

        Raises:
            AssertionError: If generate_fn was never invoked, since any
                conclusion drawn from the captured logs would be meaningless.
        """
        from openadapt_evals.training import trl_rollout

        adapter = _make_mock_adapter()
        func = make_waa_rollout_func(
            adapter,
            max_steps=1,
            max_new_tokens=max_new_tokens,
            screenshot_retries=1,
            screenshot_retry_delay=0,
        )
        trainer = _make_mock_trainer(num_generations=1)

        # Make the adapter return a real PNG for the health check.
        valid_png = self._make_png_bytes()
        adapter.observe.return_value = BenchmarkObservation(
            screenshot=valid_png, raw_observation={},
        )

        # Mock torch module with a usable ``torch.no_grad()`` context manager.
        mock_torch = MagicMock()
        mock_torch.no_grad.return_value.__enter__ = MagicMock(return_value=None)
        mock_torch.no_grad.return_value.__exit__ = MagicMock(return_value=False)

        # Any indexing/slicing of the generated sequence yields an object
        # whose ``tolist()`` is the fake completion (presumably generate_fn
        # slices the prompt tokens off the front -- verify in trl_rollout).
        prompt_len = 5
        fake_completion = list(range(completion_len))
        fake_seq = MagicMock()
        fake_seq.__getitem__ = lambda _seq, _idx: MagicMock(
            tolist=lambda: fake_completion,
        )
        mock_gen_output = MagicMock()
        mock_gen_output.sequences = [fake_seq]
        mock_gen_output.scores = []
        trainer.model.generate.return_value = mock_gen_output

        # Mock processor output whose ``input_ids`` report a known prompt length.
        mock_input_ids = MagicMock()
        mock_input_ids.shape = [1, prompt_len]
        mock_inputs = MagicMock()
        mock_inputs.__getitem__ = lambda _inputs, key: (
            mock_input_ids if key == "input_ids" else MagicMock()
        )
        mock_inputs.to.return_value = mock_inputs
        trainer.processing_class.return_value = mock_inputs
        trainer.processing_class.apply_chat_template.return_value = "fake prompt"
        trainer.processing_class.decode.return_value = '{"type": "done"}'

        gfn_calls = []

        def mock_run(env, gfn, instr, tid, ms, stuck_threshold=3):
            """Call the REAL generate_fn with mocked torch."""
            from openadapt_evals.adapters.rl_env import ResetConfig
            env.reset(config=ResetConfig(task_id=tid))

            # Swap in the fake torch only for the duration of the gfn call.
            with patch.dict("sys.modules", {"torch": mock_torch}):
                _text, ids, lps = gfn(valid_png, "test instruction")

            gfn_calls.append(ids)
            return [1], ids, lps, 0.0

        with caplog.at_level(logging.WARNING, logger=self._LOGGER_NAME), \
                patch.object(trl_rollout, "_run_episode", side_effect=mock_run):
            func(["Test task"], trainer)

        # Guard against a vacuous pass: without this, a test asserting the
        # ABSENCE of a warning would succeed even if generate_fn never ran.
        assert gfn_calls, "generate_fn was never invoked via _run_episode"

        return [record.getMessage() for record in caplog.records]

    def test_truncation_warning_logged_hf_path(self, caplog):
        """HF generate path: output hitting max_new_tokens triggers warning.

        The mocked model.generate yields max_new_tokens completion tokens so
        the truncation check in the real generate_fn fires.
        """
        messages = self._run_generate_fn(
            caplog, max_new_tokens=10, completion_len=10,
        )

        assert any(self._WARNING_FRAGMENT in m.lower() for m in messages), (
            f"Expected truncation warning from generate_fn, got: "
            f"{messages}"
        )

    def test_no_truncation_warning_when_short(self, caplog):
        """HF generate path: short output does NOT trigger warning."""
        # Only 3 completion tokens, well below max_new_tokens.
        messages = self._run_generate_fn(
            caplog, max_new_tokens=256, completion_len=3,
        )

        assert not any(self._WARNING_FRAGMENT in m.lower() for m in messages), (
            f"Should NOT see truncation warning for short output, got: "
            f"{messages}"
        )
515587
516588
0 commit comments