Skip to content

Commit ba049f7

Browse files
abrichr and claude authored
feat: add standalone GRPO trainer with WAADirect (no openadapt-ml dependency) (#191)
Self-contained GRPO training package that eliminates the openadapt-ml dependency for RL training. Uses direct HTTP calls to the WAA Flask server (WAADirect) instead of the WAALiveAdapter + RLEnvironment stack, removing version coupling and adapter indirection.

Package structure (695 LOC total):
- config.py: TrainingConfig dataclass
- waa_direct.py: WAADirect HTTP client (screenshot/click/type/key)
- prompt.py: SYSTEM_PROMPT + build_agent_messages + parse_vlm_output_to_action
- reward.py: compute_group_advantages + evaluate_milestones_screenshot
- model_loader.py: load_model_and_processor (HF + PEFT)
- trainer.py: GRPOTrainer with rollout collection + training loop

Key design decisions:
- ZERO openadapt-ml imports (self-contained, will migrate later)
- max_new_tokens=2048 default (100 was catastrophically low)
- Multi-format parser (Thought/Action, bare DSL, JSON)
- Fresh screenshot for evaluation (not cached)
- Per-step backward to avoid OOM on long trajectories
- VLM judge via OpenAI API for milestone evaluation

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bd1df33 commit ba049f7

8 files changed

Lines changed: 695 additions & 0 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Standalone GRPO trainer with direct WAA HTTP integration.
2+
3+
No openadapt-ml dependency. Will migrate to openadapt-ml later.
4+
"""
5+
6+
from openadapt_evals.training.standalone.config import TrainingConfig
7+
from openadapt_evals.training.standalone.trainer import GRPOTrainer
8+
9+
__all__ = ["GRPOTrainer", "TrainingConfig"]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Training configuration for standalone GRPO trainer."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
7+
8+
@dataclass
class TrainingConfig:
    """Settings for one standalone GRPO training run.

    Every field has a default, so ``TrainingConfig()`` is valid on its own
    and individual values can be overridden by keyword. Field order is part
    of the generated ``__init__`` signature and must not be rearranged.
    """

    # -- Model / adapter: names mirror the keywords of load_model_and_processor.
    model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"
    load_in_4bit: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_checkpoint: str | None = None

    # -- Generation / episode limits.
    num_rollouts_per_step: int = 8
    max_steps_per_episode: int = 15
    temperature: float = 0.7
    max_new_tokens: int = 2048  # 100 truncates reasoning -- keep high

    # -- Environment and task selection.
    server_url: str = "http://localhost:5001"
    task_ids: list[str] = field(default_factory=list)  # fresh list per instance
    task_dir: str | None = None
    screen_size: tuple[int, int] = (1920, 1080)  # (width, height)
    stuck_window: int = 3

    # -- Optimisation and checkpointing.
    learning_rate: float = 5e-6
    num_training_steps: int = 1000
    save_every_steps: int = 50
    output_dir: str = "checkpoints/grpo"

    # -- Evaluation. NOTE(review): presumably the judge model handed to the
    # VLM milestone evaluator -- confirm against the trainer.
    eval_model: str = "gpt-4.1-mini"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""HuggingFace + PEFT model loading for standalone GRPO. No openadapt-ml imports."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def load_model_and_processor(
    model_name: str,
    *,
    load_in_4bit: bool = True,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_checkpoint: str | None = None,
) -> tuple[Any, Any]:
    """Load a vision-language model wrapped in a LoRA adapter, plus its processor.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            processor and the base model.
        load_in_4bit: When true, load the base model with an NF4
            ``BitsAndBytesConfig`` using bfloat16 compute.
        lora_r: LoRA rank for a newly created adapter.
        lora_alpha: LoRA alpha for a newly created adapter.
        lora_checkpoint: Existing adapter to load as trainable. When set,
            ``lora_r`` and ``lora_alpha`` are not used.

    Returns:
        ``(model, processor)`` where ``model`` is the PEFT-wrapped model.
    """
    # Heavy third-party imports stay local so importing this module is cheap.
    import torch
    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoProcessor

    # Use AutoModelForImageTextToText when this transformers install has it,
    # otherwise fall back to AutoModelForVision2Seq.
    try:
        from transformers import AutoModelForImageTextToText as AutoVLM
    except ImportError:
        from transformers import AutoModelForVision2Seq as AutoVLM

    processor = AutoProcessor.from_pretrained(model_name)

    quantization_kwargs: dict[str, Any] = {}
    if load_in_4bit:
        from transformers import BitsAndBytesConfig

        quantization_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
    base_model = AutoVLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        **quantization_kwargs,
    )

    if not lora_checkpoint:
        # No checkpoint given: attach a fresh adapter to the attention projections.
        adapter_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            task_type="CAUSAL_LM",
        )
        peft_model = get_peft_model(base_model, adapter_config)
    else:
        logger.info("Loading existing LoRA from %s", lora_checkpoint)
        peft_model = PeftModel.from_pretrained(
            base_model, lora_checkpoint, is_trainable=True
        )

    peft_model.print_trainable_parameters()
    return peft_model, processor
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Prompt construction and VLM output parsing for GRPO training.
2+
3+
Copies SYSTEM_PROMPT from openadapt-ml next_action.py so GRPO
4+
operates in the same prompt distribution as SFT. NO openadapt-ml imports.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import json as _json
10+
import logging
11+
import re
12+
from dataclasses import dataclass
13+
from typing import Any
14+
15+
logger = logging.getLogger(__name__)
16+
DEFAULT_SCREEN_SIZE: tuple[int, int] = (1920, 1080)
17+
18+
# Copied from openadapt_ml.datasets.next_action.SYSTEM_PROMPT
# Kept as a verbatim copy (rather than imported) so this module has no
# openadapt-ml dependency while GRPO stays in the same prompt distribution
# as SFT. Do not reword it here without changing the upstream prompt too.
SYSTEM_PROMPT = (
    "You are a GUI automation agent. Given a screenshot and a user goal, "
    "predict the single next action.\n\n"
    "COORDINATE SYSTEM:\n"
    "- x=0.0 is the LEFT edge, x=1.0 is the RIGHT edge\n"
    "- y=0.0 is the TOP edge, y=1.0 is the BOTTOM edge\n"
    "- To click the CENTER of an element, estimate its center position "
    "as a fraction of screen width/height\n"
    "- Example: An element in the middle of the screen would be "
    "approximately x=0.5, y=0.5\n\n"
    "ALLOWED ACTIONS (use exactly this format):\n"
    "- CLICK(x=0.XX, y=0.XX) \u2192 click at normalized coordinates\n"
    '- TYPE(text="...") \u2192 type text into the currently focused field\n'
    "- WAIT() \u2192 wait for UI to update\n"
    "- DONE() \u2192 task is complete\n\n"
    "RESPONSE FORMAT (required):\n"
    "Thought: [Brief reasoning: what element to interact with and why]\n"
    "Action: [Exactly one action, e.g., CLICK(x=0.35, y=0.42)]\n\n"
    "IMPORTANT: Output coordinates with 2 decimal places. "
    "Estimate the center of target elements."
)
40+
41+
42+
@dataclass
class SimpleAction:
    """Plain action record with no openadapt-ml dependency.

    ``type`` selects the action; the remaining fields are optional payload
    and stay ``None`` when unused. A default-constructed instance is a
    ``"done"`` action.
    """

    type: str = "done"  # the parser in this module emits "click", "type", "wait" or "done"
    x: float | None = None  # click target; the parser in this module fills in pixel values
    y: float | None = None
    text: str | None = None  # payload for "type" actions
    key: str | None = None  # NOTE(review): never set by the parser in this module
52+
53+
def build_agent_messages(
    instruction: str, *, include_image: bool = False, action_history: str = "",
) -> list[dict]:
    """Build the two-message chat (system + user) in the SFT prompt format.

    Args:
        instruction: The user goal, inserted after ``"Goal: "``.
        include_image: When true, the user content is a list holding an image
            placeholder followed by the text part; otherwise it is the bare
            text string.
        action_history: Optional text placed between the goal and the
            instructions, followed by one newline. Omitted when empty.

    Returns:
        ``[system_message, user_message]`` as role/content dictionaries.
    """
    prompt_parts = [f"Goal: {instruction}\n\n"]
    if action_history:
        prompt_parts.append(f"{action_history}\n")
    prompt_parts.append(
        "Look at the screenshot and determine the NEXT action.\n\n"
        "Thought: [what element to interact with and why]\n"
        'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
    )
    prompt_text = "".join(prompt_parts)

    user_content: Any = (
        [{"type": "image"}, {"type": "text", "text": prompt_text}]
        if include_image
        else prompt_text
    )

    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]
75+
76+
77+
def parse_vlm_output_to_action(
    text: str, screen_size: tuple[int, int] = DEFAULT_SCREEN_SIZE,
) -> SimpleAction:
    """Parse raw VLM output into a SimpleAction.

    Formats are tried in this order:

    1. ``Thought: ... / Action: <action>`` -- the rest of the ``Action:`` line
       replaces the text to parse. A label at the start of a line wins over
       one embedded in other text (for example inside the Thought).
    2. JSON such as ``{"action_type": "click", "coordinate": [x, y]}``.
       Coordinates where both values are <= 1.0 are treated as normalized,
       otherwise as pixels.
    3. Bare DSL: ``CLICK(x=..., y=...)``, ``TYPE(text="...")``, ``WAIT()``,
       ``DONE()``.

    Args:
        text: Model output to parse.
        screen_size: ``(width, height)`` used to convert normalized
            coordinates to pixels and to clamp click targets.

    Returns:
        The parsed action. Click coordinates are integer pixels clamped to
        ``[0, width]`` x ``[0, height]``. Output that matches no format logs
        a warning and yields a ``"done"`` action; this function does not
        raise on malformed model output.
    """
    text = text.strip()
    width, height = screen_size
    logger.debug("Parsing VLM output (%d chars): %.200s", len(text), text)

    # Extract from "Action: ..." format. Try a line-anchored label first so
    # that prose like "Thought: my next action: ..." cannot shadow the real
    # Action line; fall back to the first occurrence anywhere.
    action_match = re.search(
        r"^\s*Action:\s*(.+)", text, re.IGNORECASE | re.MULTILINE
    ) or re.search(r"Action:\s*(.+)", text, re.IGNORECASE)
    if action_match:
        text = action_match.group(1).strip()

    # JSON: {"action_type": "click", "coordinate": [x, y]}
    json_match = re.search(r'\{[^}]*"action_type"[^}]*\}', text)
    if json_match:
        try:
            d = _json.loads(json_match.group())
            atype = d.get("action_type", "").lower()
            coord = d.get("coordinate", d.get("coords", []))
            if atype == "click" and len(coord) >= 2:
                xv, yv = float(coord[0]), float(coord[1])
                # json.loads accepts NaN/Infinity; reject them instead of
                # letting the clamp below turn them into a real click.
                if not (abs(xv) < float("inf") and abs(yv) < float("inf")):
                    raise ValueError("non-finite click coordinate")
                if xv <= 1.0 and yv <= 1.0:
                    xv, yv = xv * width, yv * height
                # Clamp to the screen, matching the DSL CLICK path.
                xv = max(0.0, min(float(width), xv))
                yv = max(0.0, min(float(height), yv))
                return SimpleAction(type="click", x=int(xv), y=int(yv))
            if atype == "type":
                return SimpleAction(type="type", text=d.get("text", ""))
            if atype in ("done", "wait"):
                return SimpleAction(type=atype)
        except (
            ValueError, TypeError, AttributeError, LookupError, OverflowError,
        ) as exc:
            # Best effort: a malformed JSON action falls through to the DSL
            # patterns, but the reason is now visible at debug level.
            logger.debug(
                "Ignoring malformed JSON action %r: %s", json_match.group(), exc
            )

    # CLICK(x=..., y=...)
    # The number pattern only accepts text float() can parse. The previous
    # "[\d.]+" also matched "..." (the placeholder in the prompt template)
    # and "1.2.3", which made float() raise an uncaught ValueError.
    number = r"(-?(?:\d+(?:\.\d*)?|\.\d+))"
    m = re.search(
        rf"CLICK\(\s*x\s*=\s*{number}\s*,\s*y\s*=\s*{number}\s*\)",
        text,
        re.IGNORECASE,
    )
    if m:
        xf = max(0.0, min(1.0, float(m.group(1))))
        yf = max(0.0, min(1.0, float(m.group(2))))
        return SimpleAction(type="click", x=int(xf * width), y=int(yf * height))

    # TYPE(text="...")
    # Each alternative excludes only its own delimiter, so an apostrophe
    # inside double quotes (TYPE(text="don't")) is accepted, and vice versa.
    m = re.search(
        r"""TYPE\(\s*text\s*=\s*(?:"((?:[^"\\]|\\.)*)"|'((?:[^'\\]|\\.)*)')\s*\)""",
        text,
        re.IGNORECASE,
    )
    if m:
        raw = m.group(1) if m.group(1) is not None else m.group(2)
        # Single-pass unescape of \\, \" and \' so an escaped backslash that
        # precedes a quote is decoded correctly; other escapes are left as-is.
        t = re.sub(r"""\\([\\"'])""", r"\1", raw)
        return SimpleAction(type="type", text=t)

    if re.search(r"\bWAIT\s*\(\s*\)", text, re.IGNORECASE):
        return SimpleAction(type="wait")
    if re.search(r"\bDONE\s*\(\s*\)", text, re.IGNORECASE):
        return SimpleAction(type="done")

    logger.warning("Could not parse VLM output: %s. Defaulting to DONE.", text)
    return SimpleAction(type="done")
129+
130+
131+
def format_action_as_text(
    action: SimpleAction, screen_size: tuple[int, int] = DEFAULT_SCREEN_SIZE,
) -> str:
    """Render a SimpleAction as DSL text for log-prob computation.

    Args:
        action: The action to render. Any ``type`` other than ``"click"``,
            ``"type"`` or ``"wait"`` renders as ``DONE()``.
        screen_size: ``(width, height)`` used to turn pixel click coordinates
            back into fractions, printed with two decimals.

    Returns:
        One of ``CLICK(x=.., y=..)``, ``TYPE(text="..")``, ``WAIT()``, ``DONE()``.
    """
    width, height = screen_size
    kind = action.type

    if kind == "wait":
        return "WAIT()"

    if kind == "type":
        raw_text = action.text or ""
        # Escape backslashes before quotes so added backslashes are not doubled.
        quoted = raw_text.replace("\\", "\\\\").replace('"', '\\"')
        return f'TYPE(text="{quoted}")'

    if kind == "click":
        # A missing coordinate counts as 0; a non-positive dimension yields 0.0.
        norm_x = (action.x or 0) / width if width > 0 else 0.0
        norm_y = (action.y or 0) / height if height > 0 else 0.0
        return f"CLICK(x={norm_x:.2f}, y={norm_y:.2f})"

    return "DONE()"
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Reward: group-relative advantages + VLM milestone evaluation. No openadapt-ml imports."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def compute_group_advantages(rewards: list[float]) -> list[float]:
    """Return GRPO group-relative advantages: (r - mean) / (std + eps).

    ``std`` is the population standard deviation of ``rewards`` and ``eps``
    is 1e-8. An empty input gives an empty list. When the rewards are
    (near-)identical (std < 1e-8) every advantage is 0.0.
    """
    count = len(rewards)
    if not count:
        return []

    center = sum(rewards) / count
    deviations = [r - center for r in rewards]
    spread = (sum(d ** 2 for d in deviations) / count) ** 0.5

    if spread < 1e-8:
        # No variation inside the group: nothing to prefer.
        return [0.0] * count

    scale = spread + 1e-8
    return [d / scale for d in deviations]
22+
23+
24+
def evaluate_milestones_screenshot(
    task_config: Any, screenshot: bytes, *, model: str = "gpt-4.1-mini",
) -> float:
    """Score a screenshot against the task's screenshot-type milestones.

    Only milestones whose ``check.check`` equals ``"screenshot"`` are judged;
    each is sent to ``vlm_judge`` together with its ``check.description``.

    Args:
        task_config: Object with an optional ``milestones`` attribute. A
            missing, ``None`` or empty value means there is nothing to judge.
        screenshot: Screenshot bytes, passed to the judge unchanged.
        model: Judge model name forwarded to ``vlm_judge``.

    Returns:
        ``passed / total`` in ``[0, 1]`` over the screenshot milestones, or
        ``0.0`` when there are none.
    """
    # `or []` also covers a config whose `milestones` attribute exists but is
    # None; the getattr default alone only covers a missing attribute.
    milestones = getattr(task_config, "milestones", None) or []
    # A milestone without a `check` object is skipped instead of raising
    # AttributeError.
    sm = [
        m
        for m in milestones
        if getattr(getattr(m, "check", None), "check", None) == "screenshot"
    ]
    if not sm:
        return 0.0

    # Imported lazily so the judge dependency is only needed when there is
    # something to judge.
    from openadapt_evals.vlm_evaluator import vlm_judge

    # The first element of vlm_judge's result is used as the pass/fail flag.
    passed = sum(
        1
        for m in sm
        if vlm_judge(screenshot, m.check.description or "", model=model)[0]
    )
    return passed / len(sm)

0 commit comments

Comments
 (0)