Skip to content

Commit 8e3bc45

Browse files
abrichr and claude
authored
fix: comprehensive prompt diagnostics for debugging garbage output (#248)
Adds detailed one-time logging to help debug the persistent garbage output issue: 1. Raw messages (role, content types, text preview) before chat template 2. Full rendered text_input (2000 chars, not 300) 3. Image metadata (mode, size, format) 4. Generation config (max_new_tokens, temperature, constrained, model type) 5. First generation output (500 chars + token count) 6. Input tensor shapes (input_ids, attention_mask, pixel_values, image_grid_thw) The tensor shape logging is critical: if pixel_values is MISSING, the model isn't seeing the screenshot — which would explain degenerate output regardless of prompt correctness. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4e49c80 commit 8e3bc45

1 file changed

Lines changed: 60 additions & 4 deletions

File tree

openadapt_evals/training/trl_rollout.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def make_waa_rollout_func(
416416
# (needs the trainer's model and processor which aren't available yet).
417417
_outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
418418
_prompt_logged: list[bool] = [False] # log the prompt once for diagnostics
419+
_output_logged: list[bool] = [False] # log first generation output
419420

420421
def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
421422
"""TRL GRPOTrainer rollout function.
@@ -517,13 +518,49 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
517518
messages, tokenize=False, add_generation_prompt=True
518519
)
519520

520-
# Log the prompt on first call so operators can verify
521-
# the correct format is being used (DSL, not JSON).
521+
# Comprehensive prompt diagnostics on first call.
522+
# This logs everything needed to debug prompt construction:
523+
# raw messages, rendered text, image presence, generation config.
522524
if not _prompt_logged[0]:
523525
_prompt_logged[0] = True
526+
# 1. Raw messages (before chat template)
527+
for i, msg in enumerate(messages):
528+
role = msg.get("role", "?")
529+
content = msg.get("content", "")
530+
if isinstance(content, list):
531+
types = [c.get("type", "?") for c in content]
532+
text_parts = [c.get("text", "")[:200] for c in content
533+
if c.get("type") == "text"]
534+
logger.info(
535+
"TRL prompt msg[%d] role=%s content_types=%s "
536+
"text=%.200s",
537+
i, role, types, " ".join(text_parts),
538+
)
539+
else:
540+
logger.info(
541+
"TRL prompt msg[%d] role=%s content=%.500s",
542+
i, role, content,
543+
)
544+
# 2. Rendered text (after chat template) — full prompt
545+
logger.info(
546+
"TRL prompt text_input (%d chars): %s",
547+
len(text_input), text_input[:2000],
548+
)
549+
# 3. Image info
550+
logger.info(
551+
"TRL prompt image: mode=%s size=%s format=%s",
552+
getattr(img, "mode", "?"),
553+
getattr(img, "size", "?"),
554+
getattr(img, "format", "?"),
555+
)
556+
# 4. Generation config
524557
logger.info(
525-
"TRL rollout prompt (first 300 chars of text_input): %.300s",
526-
text_input,
558+
"TRL generation config: max_new_tokens=%d "
559+
"temperature=%s do_sample=True "
560+
"constrained=%s model_type=%s device=%s",
561+
max_new_tokens, temperature,
562+
outlines_gen is not None,
563+
type(model).__name__, device,
527564
)
528565

529566
# --- Constrained decoding path (Outlines) ---
@@ -581,6 +618,25 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
581618

582619
text = processor.decode(completion_ids, skip_special_tokens=True)
583620

621+
# Log first generation output for debugging
622+
if not _output_logged[0]:
623+
_output_logged[0] = True
624+
logger.info(
625+
"TRL first generation output (%d tokens): %.500s",
626+
len(completion_ids), text,
627+
)
628+
# Also log input shape for vision tensor debugging
629+
logger.info(
630+
"TRL input shapes: input_ids=%s attention_mask=%s "
631+
"pixel_values=%s image_grid_thw=%s",
632+
inputs.get("input_ids", torch.tensor([])).shape,
633+
inputs.get("attention_mask", torch.tensor([])).shape,
634+
inputs.get("pixel_values", torch.tensor([])).shape
635+
if "pixel_values" in inputs else "MISSING",
636+
inputs.get("image_grid_thw", torch.tensor([])).shape
637+
if "image_grid_thw" in inputs else "MISSING",
638+
)
639+
584640
# Truncation warning — detect when output was cut off
585641
if len(completion_ids) >= max_new_tokens - 1:
586642
logger.warning(

0 commit comments

Comments
 (0)