Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions openadapt_evals/training/trl_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def make_waa_rollout_func(
# (needs the trainer's model and processor which aren't available yet).
_outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
_prompt_logged: list[bool] = [False] # log the prompt once for diagnostics
_output_logged: list[bool] = [False] # log first generation output

def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
"""TRL GRPOTrainer rollout function.
Expand Down Expand Up @@ -517,13 +518,49 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
messages, tokenize=False, add_generation_prompt=True
)

# Log the prompt on first call so operators can verify
# the correct format is being used (DSL, not JSON).
# Comprehensive prompt diagnostics on first call.
# This logs everything needed to debug prompt construction:
# raw messages, rendered text, image presence, generation config.
if not _prompt_logged[0]:
_prompt_logged[0] = True
# 1. Raw messages (before chat template)
for i, msg in enumerate(messages):
role = msg.get("role", "?")
content = msg.get("content", "")
if isinstance(content, list):
types = [c.get("type", "?") for c in content]
text_parts = [c.get("text", "")[:200] for c in content
if c.get("type") == "text"]
logger.info(
"TRL prompt msg[%d] role=%s content_types=%s "
"text=%.200s",
i, role, types, " ".join(text_parts),
)
else:
logger.info(
"TRL prompt msg[%d] role=%s content=%.500s",
i, role, content,
)
# 2. Rendered text (after chat template) — total length plus the first 2000 chars
logger.info(
"TRL prompt text_input (%d chars): %s",
len(text_input), text_input[:2000],
)
# 3. Image info
logger.info(
"TRL prompt image: mode=%s size=%s format=%s",
getattr(img, "mode", "?"),
getattr(img, "size", "?"),
getattr(img, "format", "?"),
)
# 4. Generation config
logger.info(
"TRL rollout prompt (first 300 chars of text_input): %.300s",
text_input,
"TRL generation config: max_new_tokens=%d "
"temperature=%s do_sample=True "
"constrained=%s model_type=%s device=%s",
max_new_tokens, temperature,
outlines_gen is not None,
type(model).__name__, device,
)

# --- Constrained decoding path (Outlines) ---
Expand Down Expand Up @@ -581,6 +618,25 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):

text = processor.decode(completion_ids, skip_special_tokens=True)

# Log first generation output for debugging
if not _output_logged[0]:
_output_logged[0] = True
logger.info(
"TRL first generation output (%d tokens): %.500s",
len(completion_ids), text,
)
# Also log the input tensor shapes (or "MISSING" for absent vision tensors) for debugging
logger.info(
"TRL input shapes: input_ids=%s attention_mask=%s "
"pixel_values=%s image_grid_thw=%s",
inputs.get("input_ids", torch.tensor([])).shape,
inputs.get("attention_mask", torch.tensor([])).shape,
inputs.get("pixel_values", torch.tensor([])).shape
if "pixel_values" in inputs else "MISSING",
inputs.get("image_grid_thw", torch.tensor([])).shape
if "image_grid_thw" in inputs else "MISSING",
)

# Truncation warning — detect when output was cut off
if len(completion_ids) >= max_new_tokens - 1:
logger.warning(
Expand Down
Loading