diff --git a/openadapt_evals/training/trl_rollout.py b/openadapt_evals/training/trl_rollout.py
index b2fe74c..b1c1c09 100644
--- a/openadapt_evals/training/trl_rollout.py
+++ b/openadapt_evals/training/trl_rollout.py
@@ -416,6 +416,7 @@ def make_waa_rollout_func(
     # (needs the trainer's model and processor which aren't available yet).
     _outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
     _prompt_logged: list[bool] = [False]  # log the prompt once for diagnostics
+    _output_logged: list[bool] = [False]  # log first generation output

     def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
         """TRL GRPOTrainer rollout function.
@@ -517,13 +518,49 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
                 messages, tokenize=False, add_generation_prompt=True
             )

-            # Log the prompt on first call so operators can verify
-            # the correct format is being used (DSL, not JSON).
+            # Comprehensive prompt diagnostics on first call.
+            # This logs everything needed to debug prompt construction:
+            # raw messages, rendered text, image presence, generation config.
             if not _prompt_logged[0]:
                 _prompt_logged[0] = True
+                # 1. Raw messages (before chat template)
+                for i, msg in enumerate(messages):
+                    role = msg.get("role", "?")
+                    content = msg.get("content", "")
+                    if isinstance(content, list):
+                        types = [c.get("type", "?") for c in content]
+                        text_parts = [c.get("text", "")[:200] for c in content
+                                      if c.get("type") == "text"]
+                        logger.info(
+                            "TRL prompt msg[%d] role=%s content_types=%s "
+                            "text=%.200s",
+                            i, role, types, " ".join(text_parts),
+                        )
+                    else:
+                        logger.info(
+                            "TRL prompt msg[%d] role=%s content=%.500s",
+                            i, role, content,
+                        )
+                # 2. Rendered text (after chat template) — full prompt
+                logger.info(
+                    "TRL prompt text_input (%d chars): %s",
+                    len(text_input), text_input[:2000],
+                )
+                # 3. Image info
+                logger.info(
+                    "TRL prompt image: mode=%s size=%s format=%s",
+                    getattr(img, "mode", "?"),
+                    getattr(img, "size", "?"),
+                    getattr(img, "format", "?"),
+                )
+                # 4. Generation config
                 logger.info(
-                    "TRL rollout prompt (first 300 chars of text_input): %.300s",
-                    text_input,
+                    "TRL generation config: max_new_tokens=%d "
+                    "temperature=%s do_sample=True "
+                    "constrained=%s model_type=%s device=%s",
+                    max_new_tokens, temperature,
+                    outlines_gen is not None,
+                    type(model).__name__, device,
                 )

             # --- Constrained decoding path (Outlines) ---
@@ -581,6 +618,25 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):

             text = processor.decode(completion_ids, skip_special_tokens=True)

+            # Log first generation output for debugging
+            if not _output_logged[0]:
+                _output_logged[0] = True
+                logger.info(
+                    "TRL first generation output (%d tokens): %.500s",
+                    len(completion_ids), text,
+                )
+                # Also log input shape for vision tensor debugging
+                logger.info(
+                    "TRL input shapes: input_ids=%s attention_mask=%s "
+                    "pixel_values=%s image_grid_thw=%s",
+                    inputs.get("input_ids", torch.tensor([])).shape,
+                    inputs.get("attention_mask", torch.tensor([])).shape,
+                    inputs.get("pixel_values", torch.tensor([])).shape
+                    if "pixel_values" in inputs else "MISSING",
+                    inputs.get("image_grid_thw", torch.tensor([])).shape
+                    if "image_grid_thw" in inputs else "MISSING",
+                )
+
             # Truncation warning — detect when output was cut off
             if len(completion_ids) >= max_new_tokens - 1:
                 logger.warning(
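
Aside: the patch stores its one-shot flags (`_prompt_logged`, `_output_logged`) as one-element lists rather than bare booleans so the nested closure can flip them without a `nonlocal` declaration — assigning `flag[0] = True` mutates the list in place instead of rebinding a local name. A minimal standalone sketch of that pattern follows; `make_generate_fn` and `generate_fn` here are hypothetical stand-ins for illustration, not code from this module:

import logging

logger = logging.getLogger(__name__)

def make_generate_fn():
    # One-element list: the closure below can mutate it in place.
    # A bare `logged = False` would require `nonlocal logged` to rebind.
    _logged: list[bool] = [False]

    def generate_fn(prompt: str) -> str:
        if not _logged[0]:
            _logged[0] = True  # element assignment mutates; no rebinding needed
            logger.info("first prompt (%d chars): %.300s", len(prompt), prompt)
        return prompt  # stand-in for real model generation

    return generate_fn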