Skip to content

Commit 7cf67d4

Browse files
abrichrclaude
andcommitted
feat: add GRPO training module for online RL
Add openadapt_ml/training/grpo/ package with: - GRPOConfig for training hyperparameters - GRPORolloutCollector connecting to openadapt-evals RLEnvironment - GRPOTrainer implementing custom GRPO loop for multimodal VLMs - Binary reward function and group-relative advantage computation - Chain-of-thought warm-up pipeline for SFT pre-training - 20 unit tests passing without GPU Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c04ac10 commit 7cf67d4

7 files changed

Lines changed: 1379 additions & 0 deletions

File tree

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""GRPO (Group Relative Policy Optimization) training module.
2+
3+
Provides online RL training for GUI agent VLMs using the GRPO algorithm.
4+
Connects to openadapt-evals RLEnvironment for rollout collection and
5+
task evaluation against live Windows Agent Arena VMs.
6+
7+
Key components:
8+
- GRPOConfig: Training configuration dataclass
9+
- GRPOTrainer: Main training loop
10+
- GRPORolloutCollector: Collects rollouts via RLEnvironment
11+
- reward functions: Binary task success + group-relative advantages
12+
- CoT warm-up: Chain-of-thought SFT before GRPO
13+
14+
Example:
15+
from openadapt_ml.training.grpo import GRPOConfig, GRPOTrainer
16+
17+
config = GRPOConfig(
18+
task_ids=["notepad_1", "settings_1"],
19+
num_training_steps=100,
20+
)
21+
trainer = GRPOTrainer(config)
22+
trainer.train()
23+
"""
24+
25+
from __future__ import annotations
26+
27+
from openadapt_ml.training.grpo.config import GRPOConfig
28+
from openadapt_ml.training.grpo.reward import (
29+
binary_task_success,
30+
compute_group_advantages,
31+
)
32+
from openadapt_ml.training.grpo.rollout_collector import (
33+
GRPORolloutCollector,
34+
Rollout,
35+
)
36+
from openadapt_ml.training.grpo.trainer import GRPOTrainer
37+
38+
__all__ = [
39+
"GRPOConfig",
40+
"GRPOTrainer",
41+
"GRPORolloutCollector",
42+
"Rollout",
43+
"binary_task_success",
44+
"compute_group_advantages",
45+
]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""GRPO training configuration.
2+
3+
Follows the same pattern as TRLTrainingConfig in trl_trainer.py, with
4+
additional fields for GRPO-specific hyperparameters and environment setup.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from dataclasses import dataclass, field
10+
11+
12+
@dataclass
13+
class GRPOConfig:
14+
"""Configuration for GRPO (Group Relative Policy Optimization) training.
15+
16+
Groups model/LoRA defaults with TRLTrainingConfig for consistency.
17+
18+
Attributes:
19+
model_name: HuggingFace model identifier.
20+
load_in_4bit: Whether to use 4-bit quantization.
21+
max_seq_length: Maximum sequence length for the model.
22+
lora_r: LoRA rank.
23+
lora_alpha: LoRA alpha scaling factor.
24+
num_rollouts_per_step: Group size N for GRPO advantage computation.
25+
max_steps_per_episode: Maximum actions per rollout episode.
26+
temperature: Sampling temperature for action generation during rollouts.
27+
kl_coef: KL divergence penalty coefficient against reference policy.
28+
server_url: URL of the WAA server for live environment interaction.
29+
task_ids: List of WAA task IDs to train on.
30+
learning_rate: Optimizer learning rate for LoRA parameter updates.
31+
gradient_accumulation_steps: Number of gradient accumulation steps.
32+
num_training_steps: Total number of GRPO training steps (outer loop).
33+
save_every_steps: Checkpoint frequency.
34+
output_dir: Directory for saving checkpoints and logs.
35+
stuck_window: Number of identical screenshots before early termination.
36+
"""
37+
38+
# Model (same defaults as TRLTrainingConfig)
39+
model_name: str = "unsloth/Qwen2.5-VL-7B-Instruct"
40+
load_in_4bit: bool = True
41+
max_seq_length: int = 4096
42+
43+
# LoRA
44+
lora_r: int = 16
45+
lora_alpha: int = 32
46+
47+
# GRPO-specific
48+
num_rollouts_per_step: int = 8 # Group size N
49+
max_steps_per_episode: int = 15
50+
temperature: float = 0.7 # Sampling temperature for rollouts
51+
kl_coef: float = 0.01 # KL divergence penalty
52+
53+
# Environment
54+
server_url: str = "http://localhost:5001"
55+
task_ids: list[str] = field(default_factory=list)
56+
57+
# Training
58+
learning_rate: float = 5e-6
59+
gradient_accumulation_steps: int = 8
60+
num_training_steps: int = 1000
61+
save_every_steps: int = 50
62+
output_dir: str = "checkpoints/grpo"
63+
64+
# Stuck detection
65+
stuck_window: int = 3
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""Chain-of-thought warm-up for GRPO training.
2+
3+
Provides utilities to annotate successful demonstration episodes with
4+
chain-of-thought reasoning, then convert them to SFT training format.
5+
This CoT SFT warm-up initializes the policy before GRPO online RL,
6+
giving the model a better starting point for action generation.
7+
8+
The two-step process:
9+
1. generate_cot_annotations(): Use a capable model to add reasoning
10+
to each step of successful demonstrations.
11+
2. build_cot_sft_samples(): Convert annotated episodes to the
12+
TRL SFT format used by trl_trainer.py.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import logging
18+
from typing import Any
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
def generate_cot_annotations(
24+
episodes: list[Any],
25+
annotator_model: str = "gpt-4o",
26+
) -> list[Any]:
27+
"""Add chain-of-thought reasoning to successful demonstrations.
28+
29+
For each step in a successful episode, uses the specified model to
30+
generate reasoning that explains the action choice given the
31+
screenshot context. This produces <think>...</think> blocks that
32+
teach the model to reason before acting.
33+
34+
Args:
35+
episodes: List of Episode objects from openadapt_ml.schema.
36+
Only successful episodes (episode.success == True) are
37+
annotated; others are returned unchanged.
38+
annotator_model: Model identifier for generating CoT annotations.
39+
Must support vision (image inputs).
40+
41+
Returns:
42+
List of episodes with reasoning fields populated on each step.
43+
Episodes that were already annotated or unsuccessful are
44+
returned unchanged.
45+
"""
46+
# Deferred import for optional dependency
47+
try:
48+
from openadapt_ml.models.api_adapter import get_api_adapter
49+
except ImportError:
50+
logger.error(
51+
"openadapt_ml.models.api_adapter not available. "
52+
"Cannot generate CoT annotations."
53+
)
54+
return episodes
55+
56+
annotated = []
57+
58+
for episode in episodes:
59+
if not getattr(episode, "success", False):
60+
annotated.append(episode)
61+
continue
62+
63+
instruction = getattr(episode, "instruction", "")
64+
steps = getattr(episode, "steps", [])
65+
66+
for step_idx, step in enumerate(steps):
67+
# Skip if already annotated
68+
if getattr(step, "reasoning", None):
69+
continue
70+
71+
screenshot_path = getattr(
72+
getattr(step, "observation", None),
73+
"screenshot_path",
74+
None,
75+
)
76+
action = getattr(step, "action", None)
77+
78+
if not screenshot_path or not action:
79+
continue
80+
81+
prompt = (
82+
f"You are analyzing step {step_idx + 1} of {len(steps)} "
83+
f"in a GUI automation task.\n\n"
84+
f"Task instruction: {instruction}\n\n"
85+
f"The action taken at this step was: {action}\n\n"
86+
"Explain in 1-2 sentences WHY this action was taken. "
87+
"Focus on what the agent sees on screen and how the "
88+
"action moves toward completing the task. "
89+
"Be concise and specific."
90+
)
91+
92+
try:
93+
adapter = get_api_adapter(annotator_model)
94+
reasoning = adapter.generate(
95+
{
96+
"images": [screenshot_path],
97+
"messages": [
98+
{"role": "user", "content": prompt},
99+
],
100+
},
101+
max_new_tokens=150,
102+
)
103+
step.reasoning = reasoning.strip()
104+
logger.debug(
105+
"Annotated step %d: %s",
106+
step_idx,
107+
step.reasoning[:80],
108+
)
109+
except Exception as e:
110+
logger.warning(
111+
"Failed to annotate step %d: %s", step_idx, e
112+
)
113+
114+
annotated.append(episode)
115+
116+
logger.info(
117+
"CoT annotation complete: %d episodes processed", len(annotated)
118+
)
119+
return annotated
120+
121+
122+
def build_cot_sft_samples(annotated_episodes: list[Any]) -> list[dict]:
123+
"""Convert CoT-annotated episodes to TRL SFT format.
124+
125+
Produces training samples where the assistant response includes a
126+
<think> block before the action, teaching the model to reason
127+
step-by-step during inference.
128+
129+
Format:
130+
User: <image>
131+
Instruction: Open Notepad and type Hello
132+
Previous actions: CLICK(x=0.05, y=0.95)
133+
134+
Assistant: <think>I see the Start menu is open. The task requires
135+
opening Notepad, so I need to search for it.</think>
136+
TYPE(text="notepad")
137+
138+
Args:
139+
annotated_episodes: Episodes with reasoning fields on steps,
140+
typically from generate_cot_annotations().
141+
142+
Returns:
143+
List of SFT sample dicts compatible with trl_trainer.py:
144+
{
145+
"images": [path],
146+
"messages": [system, user, assistant],
147+
}
148+
"""
149+
from openadapt_ml.datasets.next_action import (
150+
SYSTEM_PROMPT,
151+
format_action,
152+
)
153+
154+
samples: list[dict] = []
155+
156+
for episode in annotated_episodes:
157+
if not getattr(episode, "success", False):
158+
continue
159+
160+
instruction = getattr(episode, "instruction", "")
161+
steps = getattr(episode, "steps", [])
162+
163+
for step in steps:
164+
screenshot_path = getattr(
165+
getattr(step, "observation", None),
166+
"screenshot_path",
167+
None,
168+
)
169+
action = getattr(step, "action", None)
170+
reasoning = getattr(step, "reasoning", None)
171+
172+
if not screenshot_path or not action:
173+
continue
174+
175+
# Build action history
176+
step_index = getattr(step, "step_index", 0)
177+
prev_actions = []
178+
for prev_step in steps:
179+
prev_idx = getattr(prev_step, "step_index", 0)
180+
if prev_idx < step_index:
181+
prev_actions.append(
182+
format_action(prev_step.action)
183+
)
184+
185+
# Build user content
186+
parts = [f"Instruction: {instruction}"]
187+
if prev_actions:
188+
parts.append(
189+
"Previous actions: "
190+
+ " -> ".join(prev_actions)
191+
)
192+
user_content = "\n".join(parts)
193+
194+
# Build assistant content with CoT
195+
action_text = format_action(action)
196+
if reasoning:
197+
assistant_content = f"<think>{reasoning}</think>\n{action_text}"
198+
else:
199+
assistant_content = action_text
200+
201+
sample = {
202+
"images": [screenshot_path],
203+
"messages": [
204+
{"role": "system", "content": SYSTEM_PROMPT},
205+
{"role": "user", "content": user_content},
206+
{"role": "assistant", "content": assistant_content},
207+
],
208+
}
209+
samples.append(sample)
210+
211+
logger.info("Built %d CoT SFT samples", len(samples))
212+
return samples
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Reward functions for GRPO training.
2+
3+
Provides binary task-success rewards and group-relative advantage
4+
computation following the GRPO algorithm (Shao et al., 2024).
5+
6+
GRPO computes advantages relative to the group mean rather than using
7+
a learned value function, which is simpler and works well for sparse
8+
binary rewards (task success/failure).
9+
"""
10+
11+
from __future__ import annotations
12+
13+
14+
def binary_task_success(score: float, threshold: float = 0.5) -> float:
15+
"""Convert evaluator score to binary reward.
16+
17+
Args:
18+
score: Raw evaluator score (0.0-1.0) from WAA environment.
19+
threshold: Score at or above which the task is considered successful.
20+
21+
Returns:
22+
1.0 if score >= threshold, else 0.0.
23+
"""
24+
return 1.0 if score >= threshold else 0.0
25+
26+
27+
def compute_group_advantages(rewards: list[float]) -> list[float]:
28+
"""Compute group-relative advantages for a batch of rollout rewards.
29+
30+
GRPO normalizes rewards within each group:
31+
advantage[i] = (reward[i] - mean) / (std + eps)
32+
33+
If all rewards are identical (no variance), returns all zeros. This
34+
avoids NaN from division by zero and correctly signals that there is
35+
no gradient signal when every rollout in the group has the same outcome.
36+
37+
Args:
38+
rewards: List of scalar rewards for each rollout in the group.
39+
40+
Returns:
41+
List of advantage values, same length as rewards.
42+
"""
43+
n = len(rewards)
44+
if n == 0:
45+
return []
46+
47+
mean = sum(rewards) / n
48+
variance = sum((r - mean) ** 2 for r in rewards) / n
49+
std = variance**0.5
50+
eps = 1e-8
51+
52+
# No variance means no gradient signal: all advantages are zero
53+
if std < eps:
54+
return [0.0] * n
55+
56+
return [(r - mean) / (std + eps) for r in rewards]

0 commit comments

Comments
 (0)