
Commit 339e5d3

abrichr and claude authored
feat: add GRPO training module with minimal TRL bridge (#34)
* docs: add experimental roadmap and evidence context to vision

  - Add 2x2 experimental matrix (retrieval × fine-tuning) to Core Thesis
  - Add evidence context to the benchmark table: note that it is an internal synthetic benchmark (~3 UI elements) that validates the pipeline, not real-world performance. Link to openadapt-evals for ongoing WAA/OSWorld evaluation.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: use 46.7% consistently in 2x2 matrix

  Was showing a 33-47% range, which conflated preliminary (n=3) and full (n=45) results. The validated number is 46.7%.

* feat: add GRPO training module for online RL

  Add the openadapt_ml/training/grpo/ package with:
  - GRPOConfig for training hyperparameters
  - GRPORolloutCollector connecting to the openadapt-evals RLEnvironment
  - GRPOTrainer implementing a custom GRPO loop for multimodal VLMs
  - Binary reward function and group-relative advantage computation
  - Chain-of-thought warm-up pipeline for SFT pre-training
  - 20 unit tests passing without a GPU

* fix: address review findings in GRPO module

  - Replace copy.deepcopy(model) with a LoRA state dict snapshot (prevents OOM)
  - Mark _compute_rollout_loss as a scaffold with a dummy forward pass for grad flow
  - Fix the collect_rollout call to match the RLEnvironment API (task_id in signature)
  - Add model.eval()/model.train() toggling around the rollout/training phases
  - Remove the unused gradient_accumulation_steps config field
  - Use the actual screen_size from RLEnvironment instead of a hardcoded 1920x1200
  - Clamp CLICK coordinates to [0.0, 1.0] to prevent invalid pixel values
  - Validate task_ids is non-empty at the start of train()
  - Export the CoT warmup functions from the package __init__
  - Add a BenchmarkAction fallback when openadapt-evals is not installed
  - Add 9 new tests: action parser (8) + empty task_ids validation (1)
  - All 29 tests passing

* feat: implement GRPO loss computation and fix cot_warmup dependency

  Implement the core _compute_rollout_loss method that was previously a NotImplementedError scaffold. The implementation:
  - Reconstructs VLM prompts from rollout observations
  - Formats actions back to DSL text via the new _format_action_as_text helper
  - Computes log-probabilities of action tokens under the current policy
  - Computes reference-policy log-probs via PEFT disable_adapter(), with a fallback to manual LoRA weight swapping
  - Returns the GRPO loss: -advantage * log_prob + kl_coef * KL penalty

  Also adds a get_api_adapter() factory function to api_adapter.py, fixing the broken import in cot_warmup.py's generate_cot_annotations().

  Additional review fixes from the prior session:
  - Initialize _is_unsloth and _ref_lora_state in __init__
  - Remove the dead else branch for task_id selection
  - Fix total_loss device placement
  - LoRA-only fallback save in checkpoint
  - TYPE regex accepts single quotes
  - Coordinate clamping in _parse_vlm_output_to_action

  40 tests passing (10 new: 8 format_action + 1 roundtrip + 1 api_adapter).

* refactor: deduplicate GRPO prompts via shared _build_agent_messages

  Extract prompt construction into _build_agent_messages(), which imports SYSTEM_PROMPT from next_action.py (the SFT training prompt). This ensures the GRPO agent uses the same prompt distribution the model was warm-started on, and guarantees _make_agent_fn and _compute_rollout_loss use identical prompts (critical for correct log-prob computation).

* fix(grpo): address critical review findings in GRPO loss computation

  - C-01: Store the raw model output on action._grpo_raw_text for accurate loss
  - C-02: Separate tokenization of prompt/action with concatenation to fix BPE boundary alignment
  - I-01: Prefer LoRA weight swapping over disable_adapter() for the reference policy (captures the initial LoRA state after the SFT warm-start)
  - I-03: Per-step gradient accumulation via immediate backward() to prevent OOM from building the computation graph over all rollout steps
  - I-04: Fix unescape order in the TYPE parser (backslash before quotes)
  - M-03: Pass model_name through get_api_adapter to ApiVLMAdapter
  - M-07: Case-insensitive CLICK/TYPE regex in _parse_vlm_output_to_action
  - L-01: Extract the DEFAULT_SCREEN_SIZE constant, replace all hardcoded values

* fix(grpo): fix instruction propagation, screen size, weight swap safety

  - CR-01: The task instruction was never populated during GRPO rollouts. WAALiveAdapter._get_observation() does not populate raw_observation, so the agent prompt said "Goal: " with nothing after it. Fix: store the instruction on the Rollout dataclass (populated from env._current_task in the collector) and use it in both agent_fn and _compute_rollout_loss.
  - IM-01: Change DEFAULT_SCREEN_SIZE from 1920x1200 to 1920x1080 for consistency with the baselines module and standard VM configurations. Add a screen_size field to GRPOConfig so it is configurable.
  - IM-02: Add try/finally around the LoRA weight swap in _compute_ref_log_probs. Without this, an exception during the reference forward pass permanently corrupts the model state.

* fix(grpo): remove unused torch import in _setup_model

  The import torch at line 121 was flagged by ruff (F401) as unused. The surrounding code only calls .detach().clone() on tensor objects, which does not require the torch module directly.

* style(grpo): apply ruff formatting to GRPO module files

  Run ruff format on cot_warmup.py, rollout_collector.py, and trainer.py to satisfy the CI ruff formatter check.

* refactor(grpo): replace custom trainer with minimal TRL bridge

  Replace the 809-line custom GRPO trainer with ~280 lines that:
  - Use standard HuggingFace AutoModelForVision2Seq + AutoProcessor + PEFT LoraConfig instead of Unsloth monkey-patching
  - Implement a standalone GRPO loss in ~15 lines of PyTorch (clipped surrogate) instead of a custom policy gradient + KL penalty
  - Use beta=0.0 (no KL penalty, no reference model) per the DAPO/Open-Reasoner-Zero literature, eliminating weight-swap complexity
  - Keep per-step backward to avoid OOM on long trajectories
  - Use standard model.save_pretrained() for checkpointing
  - Document WHY standalone GRPO math vs TRL GRPOTrainer (VLM multi-turn image pixel_values are not stored in token IDs) and WHEN to switch

  Preserves the full public API: GRPOTrainer, _parse_vlm_output_to_action, _format_action_as_text, _build_agent_messages, DEFAULT_SCREEN_SIZE. All 50 tests pass (44 existing + 6 new for grpo_loss and trainer internals).

* feat(grpo): add E2E tests with artifact generation and architecture docs

  - tests/test_grpo_e2e.py: 5 E2E tests (training loop, rollout collection, loss convergence, weight diff, mathematical properties) using a tiny mock VLM. Produces 65+ artifacts (JSON traces, PNGs, checkpoints, summaries).
  - scripts/grpo_e2e_report.py: CLI report generator for the test artifacts (text + optional HTML output).
  - docs/grpo_e2e_test_design.md: design rationale for the E2E test approach
  - docs/grpo_architecture_analysis.md: analysis of custom vs TRL-based GRPO
  - docs/grpo_trl_rewrite_draft.py: TRL v0.29.0 integration research
  - docs/strategic_analysis_evals_ml_synergy.md: business/economics analysis

* fix(grpo): address self-review findings (BUG-01, CLEAN-01 through -05)

  - Rename grpo_loss to policy_gradient_loss with an honest docstring: single-epoch on-policy means ratio=1.0 and clipping never fires, so this is REINFORCE with group-relative advantages. Keep grpo_loss as a backwards-compatible alias.
  - Add public aliases: parse_vlm_output_to_action, format_action_as_text (drop the underscore prefix for the public API)
  - Export policy_gradient_loss and the public functions from __init__.py
  - Remove unused config fields: kl_coef (was 0.01 but never used with beta=0), max_seq_length (never referenced)
  - Fix the model_name default: Qwen/Qwen2.5-VL-7B-Instruct (not the unsloth variant)
  - Fix a trivial test assertion: grad_norm > 0 (was >= 0, always true)
  - Update the loss tests to verify gradient direction, not just loss sign
  - Add test_public_api_exports for the new public names

  56 tests pass (51 unit + 5 E2E).

--------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 70a0c49 commit 339e5d3

15 files changed

Lines changed: 5474 additions & 2 deletions

docs/grpo_architecture_analysis.md

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# GRPO Architecture Analysis: Custom vs TRL-Based Approach

## The Problem: 26 Issues From One Root Cause

After a comprehensive review of our custom GRPO trainer (~809 lines), we identified 26 issues (7 critical, 8 important, 7 medium, 4 low). The sheer count is a code smell pointing to an architectural problem rather than implementation bugs.

**Root cause**: We wrote a custom GRPO trainer that reimplements what TRL now provides natively, while also tightly coupling RL math with WAA-specific glue code.

## Breakdown of Our 809-Line Trainer

| Category | Lines | What It Does |
|----------|-------|--------------|
| GRPO Math | ~190 | Advantage computation, KL penalty, policy gradient loss, reference policy |
| Infrastructure/Glue | ~180 | Model loading, LoRA setup, optimizer, checkpointing, training loop |
| Unique to Our Use Case | ~400+ | Multi-turn rollout processing, DSL parsing, prompt formatting, observation handling |

## What TRL v0.29.0 Now Provides

TRL's GRPOTrainer (as of Feb 2026) supports:

1. **Multi-turn rollouts** via `rollout_func` (v0.29.0) — you provide a custom function that replaces TRL's generation loop and returns `prompt_ids`, `completion_ids`, and `logprobs`. A Wordle example shows 6-turn interactive loops.
2. **`environment_factory`** (v0.29.0) — stateful environments with `reset()` and arbitrary methods as tools. One instance per rollout.
3. **Multimodal VLMs** including Qwen2.5-VL — natively supported since v0.20.0.
4. **Custom reward functions** — pass a callable; supports async, multiple functions, environment access, and extra rollout fields forwarded as kwargs.
5. **LoRA + quantization** — standard PEFT integration.
6. **Gradient accumulation** — standard HF Trainer mechanisms plus `steps_per_generation`.
7. **Advanced loss variants** — `dapo`, `dr_grpo`, `bnpo`, asymmetric clipping, Liger kernel fusion.

## Which Issues Vanish With TRL?

~14 of 26 issues are eliminated by delegating to TRL:

- **CR-03** (custom GRPO duplicates TRL): Eliminated by definition
- **CR-07** (untested training loop): TRL is battle-tested
- **IM-03** (no error handling in rollouts): TRL handles generation errors
- **IM-05** (prompt misalignment risk): TRL manages tokenization
- **IM-06** (monkey-patch Unsloth loading): Use TRL's standard model loading
- **IM-07** (LoRA param capture fragile): TRL handles reference policy
- **MD-01** (no gradient clipping): TRL includes it
- **MD-02** (no LR scheduler): TRL includes standard schedulers
- **MD-03** (no WandB logging): TRL integrates with all HF loggers
- **MD-04** (hardcoded AdamW): TRL supports all optimizers
- **MD-05** (no multi-GPU): TRL + accelerate/DeepSpeed handles this
- **MD-06** (no mixed precision): TRL handles bf16/fp16
- **LO-01** (verbose step logging): TRL's logging is configurable
- **LO-02** (no TensorBoard): TRL integrates natively

## The Key Gap: Multi-Turn Interactive Rollouts

TRL is fundamentally **single-turn**: prompt -> completion -> reward. Even with `rollout_func`, the advantage is computed at the trajectory level (one reward per complete rollout), not per-step.

But this actually **matches our use case**:

- WebAgent-R1 uses binary task-success rewards (0 or 1)
- GRPO computes group-relative advantages across N trajectories of the same task
- We don't need per-step credit assignment — trajectory-level reward is sufficient

The `rollout_func` approach lets us:

1. Call our `RLEnvironment.collect_rollout()` to get interactive multi-step trajectories
2. Return the concatenated token IDs and log-probs to TRL
3. Let TRL handle advantage computation, clipping, KL penalty, and optimization
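The three steps above can be sketched as a thin bridge function. This is an illustrative shape only: the `env.collect_rollout()` step format and the `tokenize` helper are placeholders invented here, not the real openadapt-evals API; the three returned keys are the ones the doc says TRL consumes.

```python
# Sketch of a TRL-style rollout_func over an interactive environment.
# Assumption: env.collect_rollout(task) yields (prompt_text, action_text, step_logprobs)
# tuples per step; real code would tokenize with the model's processor.

def make_rollout_func(env, tokenize):
    def rollout_func(prompts, **kwargs):
        prompt_ids, completion_ids, logprobs = [], [], []
        for task in prompts:
            p_ids, c_ids, lps = [], [], []
            for prompt_text, action_text, step_logprobs in env.collect_rollout(task):
                p_ids.extend(tokenize(prompt_text))  # multi-turn prompts, concatenated
                c_ids.extend(tokenize(action_text))  # the model's action tokens
                lps.extend(step_logprobs)            # per-token log-probs at sampling time
            prompt_ids.append(p_ids)
            completion_ids.append(c_ids)
            logprobs.append(lps)
        # The three fields a custom rollout_func hands back to TRL's GRPOTrainer.
        return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
    return rollout_func
```

The reward function then scores each trajectory as a whole, matching the trajectory-level advantage computation described above.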

## Proposed Architecture

```
TRL GRPOTrainer           <- standard, maintained, tested (0 lines from us)
  |
  +-- rollout_func        <- ~100 lines (our custom rollout function)
  |     Uses RLEnvironment to collect interactive multi-step trajectories
  |     Returns prompt_ids, completion_ids, logprobs
  |
  +-- reward_func         <- ~20 lines (already exists in reward.py)
  |     binary_task_success() + compute_group_advantages()
  |
  +-- RolloutCollector    <- ~150 lines (already exists)
  |     collect_group() orchestrates N rollouts per task
  |
  +-- RLEnvironment       <- openadapt-evals (already exists, PR #73)
        reset() / step() / observe() / evaluate()
```

**Our code shrinks from ~800 lines to ~200 lines** of genuine domain-specific logic:

- `rollout_func`: bridges TRL's generation loop with our interactive environment
- Action DSL parsing (CLICK/TYPE/WAIT/DONE)
- Prompt construction for multi-turn VLM interaction
- Reward function (already exists)
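The DSL-parsing bullet is the kind of code that stays ours either way. A minimal illustrative parser — the exact grammar here is assumed, not taken from the trainer, and the coordinate clamping mirrors the fix described in the commit message:

```python
import re

# Hypothetical DSL grammar: CLICK(x, y) with normalized coords, TYPE("text"), WAIT, DONE.
CLICK_RE = re.compile(r"CLICK\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*\)", re.IGNORECASE)
TYPE_RE = re.compile(r"TYPE\(\s*['\"](.*?)['\"]\s*\)", re.IGNORECASE | re.DOTALL)

def parse_action(text: str):
    """Return a tuple describing the first action found in the VLM output, or None."""
    if m := CLICK_RE.search(text):
        # Clamp to [0.0, 1.0] so out-of-range coords cannot produce invalid pixel values.
        x = min(max(float(m.group(1)), 0.0), 1.0)
        y = min(max(float(m.group(2)), 0.0), 1.0)
        return ("CLICK", x, y)
    if m := TYPE_RE.search(text):
        return ("TYPE", m.group(1))
    if re.search(r"\bWAIT\b", text, re.IGNORECASE):
        return ("WAIT",)
    if re.search(r"\bDONE\b", text, re.IGNORECASE):
        return ("DONE",)
    return None
```

Case-insensitive matching and single-quote support correspond to the M-07 and TYPE-regex fixes listed in the commit history.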

## What About WebAgent-R1 and Agent-R1?

Both build on **veRL** (ByteDance's RL framework), NOT TRL. They implement their own multi-turn GRPO from scratch. Key results:

- WebAgent-R1: Qwen-2.5-3B went 6.1% -> 33.9% on WebArena-Lite
- Agent-R1: supports PPO, GRPO, REINFORCE++ with per-tool-call process rewards

We could also consider veRL, but TRL has better ecosystem integration (HF Hub, PEFT, quantization, vLLM) and the `rollout_func` API is flexible enough for our needs.

## Standalone GRPO Math (Fallback Option)

If TRL's `rollout_func` proves too constraining, the GRPO math is ~30 lines of PyTorch:

```python
# Advantage (group-normalized), G rollouts per task
mean_r = rewards.reshape(-1, G).mean(dim=1, keepdim=True)
std_r = rewards.reshape(-1, G).std(dim=1, keepdim=True)
advantages = (rewards - mean_r.repeat(1, G).flatten()) / (std_r.repeat(1, G).flatten() + 1e-4)

# KL penalty (Schulman 2020 approximation)
x = ref_logps - current_logps
kl = torch.exp(x) - x - 1

# Clipped surrogate loss
ratio = torch.exp(current_logps - old_logps)
clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
loss = -torch.min(ratio * advantages, clipped * advantages) + beta * kl
```

This gives us full control while still eliminating the infrastructure/glue code by using HF Trainer for the training loop.
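As a sanity check that this math behaves as described, here is a self-contained runnable version. The function names and the `group_size` reshaping convention are ours, invented for illustration; beta is taken as 0 (as in the final trainer), so the KL term is dropped.

```python
import torch

def group_advantages(rewards: torch.Tensor, group_size: int, eps: float = 1e-4) -> torch.Tensor:
    """Group-normalized advantages; rewards is flat, length num_tasks * group_size."""
    r = rewards.reshape(-1, group_size)
    norm = (r - r.mean(dim=1, keepdim=True)) / (r.std(dim=1, keepdim=True) + eps)
    return norm.flatten()

def clipped_surrogate(current_logps, old_logps, advantages, clip_eps=0.2):
    """PPO-style clipped surrogate; with beta = 0 the KL term above is simply omitted."""
    ratio = torch.exp(current_logps - old_logps)
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    return -torch.min(ratio * advantages, clipped * advantages).mean()

# Two tasks, two rollouts each: within each group the winning rollout
# gets a positive advantage and the losing one a negative advantage.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])
advantages = group_advantages(rewards, group_size=2)
logps = torch.zeros(4, requires_grad=True)
loss = clipped_surrogate(logps, logps.detach(), advantages)
loss.backward()  # gradient descent raises the winners' log-probs, lowers the losers'
```

On-policy (current == old) the ratio is 1.0 and clipping never fires, which is exactly the "this is REINFORCE with group-relative advantages" observation in the final commit.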

## Recommendation

1. **Merge PR #73** (openadapt-evals RL environment) — stable foundation, CI passing
2. **Don't merge PR #34 as-is** — the custom trainer has too many issues
3. **Rewrite the GRPO module** as a thin TRL adapter using `rollout_func`:
   - Keep: rollout_collector.py, reward.py, config.py, cot_warmup.py
   - Replace: trainer.py (800 lines -> ~200 lines)
   - Delete: all custom GRPO math, model loading, optimizer, checkpointing
4. **Close ~14 GitHub issues** that become N/A with TRL delegation

## TRL Version Compatibility Note

TRL v0.29.0's `rollout_func` requires `transformers>=5.2.0`. Verify that this works with Unsloth and our quantization setup before committing to this path.

## References

- [TRL GRPOTrainer docs](https://huggingface.co/docs/trl/main/en/grpo_trainer)
- [TRL OpenEnv integration](https://huggingface.co/docs/trl/main/en/openenv)
- [TRL v0.29.0 release](https://github.com/huggingface/trl/releases/tag/v0.29.0)
- [WebAgent-R1 paper](https://arxiv.org/abs/2505.16421)
- [Agent-R1 (veRL-based)](https://github.com/0russwest0/Agent-R1)
- GitHub issues: openadapt-ml #35-#50, #42 (tracking)
- GitHub issues: openadapt-evals #76-#78

docs/grpo_e2e_test_design.md

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# GRPO E2E Test Design

## Date: 2026-03-02

## Problem

The GRPO trainer was recently rewritten. The existing tests in `tests/test_grpo.py` are unit tests that mock everything and only verify individual functions in isolation. We need end-to-end tests that exercise the full training loop and produce artifacts a human can inspect to verify correctness.

## What a human reviewer needs to see

1. **Did the training loop run without errors?** -- test report with pass/fail, duration, error traces.
2. **Did the model weights change?** -- LoRA parameter diff (L2 norm of the delta) before vs after training. If the weights did not change, training is broken.
3. **Were rollouts collected and rewards computed?** -- rollout traces showing the sequence of (screenshot, action, reward) for each rollout.
4. **Is the loss signal reasonable?** -- per-step metrics: loss, reward_mean, advantage stats, gradient norm.
5. **Can the checkpoint be saved and reloaded?** -- verify the saved LoRA adapter can be loaded back.
6. **Does the GRPO loss function actually drive the policy toward high-reward actions?** -- synthetic convergence test with controlled log-probs and rewards.

## Design Options Considered

### Option A: pytest with artifact directory

- Standard pytest tests write artifacts to `test_artifacts/grpo_e2e/`.
- Pros: CI integration, no extra dependencies, familiar.
- Cons: artifacts are just files on disk; viewing them requires a separate step.

### Option B: Standalone script

- `scripts/run_e2e_test.py` with an HTML report.
- Pros: rich output, self-contained.
- Cons: does not integrate with CI.

### Option C: pytest + HTML report plugin (pytest-html)

- Best of both worlds, but adds a dependency.

### Option D: pytest + artifact directory + separate summary script

- pytest writes artifacts; `scripts/grpo_e2e_report.py` reads them and prints a formatted summary (or generates HTML).
- Pros: separation of concerns, can re-run the report without re-running tests, CI-friendly.
- Cons: two invocations.

### Chosen: Option D

Reasoning:

- The user wants to "look at" results -- a summary script can print a clean, readable report without adding pytest-html as a dependency.
- Tests work in CI (pytest) and locally (run the report script afterwards).
- Artifacts tell the full story: JSON metrics, PNG screenshots, rollout traces.
- The report script can be extended later to generate HTML without changing the tests.

## Test Architecture

### Mock Strategy

We do NOT load a real Qwen2.5-VL model (too slow, too large). Instead:

1. **Mock model**: a tiny `nn.Module` with a single linear layer plus LoRA-like trainable params. It accepts `input_ids` and returns logits. This lets us test that gradients flow and weights update without needing a 7B model.
2. **Mock processor**: returns pre-built tensors. Has `apply_chat_template`, `decode`, and `__call__` methods.
3. **Mock environment**: generates synthetic screenshots (colored rectangles with text via PIL) and returns mock `RolloutStep` objects with realistic `BenchmarkObservation` and `BenchmarkAction` data. The reward is deterministic based on the action.
4. **Mock rollout collector**: replaces `GRPORolloutCollector` -- returns pre-built `Rollout` objects with mock steps that contain PNG screenshot bytes.

This way:

- The training loop (optimizer, loss computation, checkpointing) is exercised for real.
- Artifacts contain visually meaningful screenshots.
- Tests run in < 60s on CPU.
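A mock model of the kind item 1 describes can be very small. The sketch below is an illustrative stand-in, not the actual test fixture; the class and attribute names are invented here.

```python
import torch
from torch import nn

class TinyMockVLM(nn.Module):
    """Embedding -> LoRA-like trainable delta -> LM head; enough to exercise the loop."""

    def __init__(self, vocab_size: int = 32, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lora_delta = nn.Linear(hidden, hidden, bias=False)  # stands in for LoRA params
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.embed(input_ids)
        return self.lm_head(h + self.lora_delta(h))  # logits: (batch, seq, vocab)

# One optimizer step should move the LoRA-like weights -- the "did weights change?" check.
model = TinyMockVLM()
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
before = model.lora_delta.weight.detach().clone()
logits = model(torch.randint(0, 32, (2, 5)))
loss = logits.pow(2).mean()  # arbitrary differentiable objective, just to drive gradients
loss.backward()
opt.step()
delta_norm = (model.lora_delta.weight.detach() - before).norm().item()
```

The L2 norm of the weight delta is the same statistic the E2E tests record in `model_diff.json`.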

### Tests

1. **`test_e2e_training_loop_mock`** -- Full loop: 2 training steps, 2 rollouts each. Verifies weights change, loss is computed, checkpoint is saved and loadable.

2. **`test_e2e_rollout_collection_mock`** -- Collects rollouts from the mock environment, saves traces (JSON) and screenshots (PNG) as artifacts.

3. **`test_e2e_grpo_loss_convergence`** -- Synthetic test: creates fake log-probs (as trainable parameters) and rewards, runs GRPO loss + optimizer for 50 steps, verifies the "policy" shifts probability toward high-reward actions.

### Artifacts Written

```
test_artifacts/grpo_e2e/<timestamp>/
  test_report.json      -- overall pass/fail, timing, errors
  training_log.json     -- per-step metrics from the training loop
  rollout_traces/
    step_0_rollout_0.json             -- per-rollout trace
    step_0_rollout_0_screenshot_0.png
    ...
  model_diff.json       -- LoRA weight delta stats
  checkpoint/           -- saved LoRA adapter
  convergence/
    loss_history.json      -- loss values over 50 synthetic steps
    advantage_policy.json  -- policy probabilities over time
  summary.txt           -- human-readable summary
```

### Report Script

`scripts/grpo_e2e_report.py` reads the artifact directory and prints:

- Test status (pass/fail per test)
- Training metrics summary
- Model weight change (did the LoRA params move?)
- Convergence check (did the loss decrease in the synthetic test?)
- A file listing of all artifacts

Uses `fire` for the CLI: `python scripts/grpo_e2e_report.py <artifact_dir>`
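A minimal version of such a report script might look like this. It is a sketch of the idea, not the actual `grpo_e2e_report.py`; the only assumed layout detail is the `test_report.json` file from the artifact tree above.

```python
import json
from pathlib import Path

def summarize(artifact_dir: str) -> str:
    """Build a human-readable summary of one E2E artifact directory."""
    root = Path(artifact_dir)
    lines = []
    report_path = root / "test_report.json"
    if report_path.exists():
        report = json.loads(report_path.read_text())
        lines.append(f"status: {report.get('status', 'unknown')}")
    lines.append("artifacts:")
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        lines.append(f"  {path.relative_to(root)}")
    return "\n".join(lines)

if __name__ == "__main__":
    import fire  # fire turns the function signature into a CLI, as the doc suggests

    fire.Fire(summarize)
```

Keeping `fire` inside the `__main__` guard means the summarizing logic itself stays importable and testable without the CLI dependency.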
