diff --git a/openadapt_ml/training/grpo/trainer.py b/openadapt_ml/training/grpo/trainer.py
index bf7389b..88c1c86 100644
--- a/openadapt_ml/training/grpo/trainer.py
+++ b/openadapt_ml/training/grpo/trainer.py
@@ -98,12 +98,16 @@ def policy_gradient_loss(
 
 
 def _build_agent_messages(
-    instruction: str, *, include_image: bool = False
+    instruction: str,
+    *,
+    include_image: bool = False,
+    action_history: str = "",
 ) -> list[dict]:
     """Build chat messages for the GRPO agent.
 
-    Uses the same SYSTEM_PROMPT as SFT training so GRPO operates in
-    the same prompt distribution the model was warm-started on.
+    Uses the same SYSTEM_PROMPT and prompt format as SFT training
+    (``next_action.py``) so GRPO operates in the same prompt
+    distribution the model was warm-started on.
 
     This is the **single source of truth** for prompt construction
     during both rollout collection and loss computation.
@@ -113,10 +117,15 @@ def _build_agent_messages(
         include_image: If True, include an image placeholder in the
             user message so ``apply_chat_template`` inserts ``<|image_pad|>``
             tokens required by Qwen2.5-VL and similar VLMs.
+        action_history: Formatted action history from previous steps
+            (e.g. "Step 1: CLICK(x=0.5, y=0.3)\\nStep 2: TYPE(...)").
     """
+    history_text = f"{action_history}\n" if action_history else ""
     text_content = (
         f"Goal: {instruction}\n\n"
+        f"{history_text}"
         "Look at the screenshot and determine the NEXT action.\n\n"
+        "Thought: [what element to interact with and why]\n"
         'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
     )
     if include_image:
@@ -161,6 +170,37 @@ class BenchmarkAction:  # type: ignore[no-redef]
     text = text.strip()
     width, height = screen_size
 
+    # Log raw output for debugging zero-reward issues
+    logger.debug("Parsing VLM output (%d chars): %.200s", len(text), text)
+
+    # Extract action from "Thought: ...\nAction: ..." format (SFT output)
+    action_match = re.search(r"Action:\s*(.+)", text, re.IGNORECASE)
+    if action_match:
+        text = action_match.group(1).strip()
+
+    # Try JSON format: {"action_type": "click", "coordinate": [x, y]}
+    json_match = re.search(r'\{[^}]*"action_type"[^}]*\}', text)
+    if json_match:
+        try:
+            import json as _json
+            action_data = _json.loads(json_match.group())
+            atype = action_data.get("action_type", "").lower()
+            coord = action_data.get("coordinate", action_data.get("coords", []))
+            if atype == "click" and len(coord) >= 2:
+                x_val, y_val = float(coord[0]), float(coord[1])
+                # Handle both normalized (0-1) and pixel coordinates
+                if x_val <= 1.0 and y_val <= 1.0:
+                    x_val, y_val = x_val * width, y_val * height
+                return BenchmarkAction(type="click", x=int(x_val), y=int(y_val))
+            if atype == "type":
+                return BenchmarkAction(
+                    type="type", text=action_data.get("text", "")
+                )
+            if atype in ("done", "wait"):
+                return BenchmarkAction(type=atype)
+        except Exception:
+            pass  # Fall through to regex parsing
+
     # CLICK(x=..., y=...)
     m = re.search(r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)", text, re.IGNORECASE)
     if m: