
Commit 31cadf8

abrichr and claude committed
feat(grpo): add E2E tests with artifact generation and architecture docs
- tests/test_grpo_e2e.py: 5 E2E tests (training loop, rollout collection, loss convergence, weight diff, mathematical properties) using tiny mock VLM. Produces 65+ artifacts (JSON traces, PNGs, checkpoints, summaries).
- scripts/grpo_e2e_report.py: CLI report generator for test artifacts (text + optional HTML output).
- docs/grpo_e2e_test_design.md: design rationale for E2E test approach
- docs/grpo_architecture_analysis.md: analysis of custom vs TRL-based GRPO
- docs/grpo_trl_rewrite_draft.py: TRL v0.29.0 integration research
- docs/strategic_analysis_evals_ml_synergy.md: business/economics analysis

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3ebcc36 commit 31cadf8

6 files changed

Lines changed: 3527 additions & 0 deletions

docs/grpo_architecture_analysis.md

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# GRPO Architecture Analysis: Custom vs TRL-Based Approach

## The Problem: 26 Issues From One Root Cause

After a comprehensive review of our custom GRPO trainer (~809 lines), we identified
26 issues (7 critical, 8 important, 7 medium, 4 low). The sheer count is a code smell
pointing to an architectural problem rather than implementation bugs.

**Root cause**: We wrote a custom GRPO trainer that reimplements what TRL now provides
natively, while also tightly coupling RL math with WAA-specific glue code.

## Breakdown of Our 809-Line Trainer

| Category | Lines | What It Does |
|----------|-------|-------------|
| GRPO Math | ~190 | Advantage computation, KL penalty, policy gradient loss, reference policy |
| Infrastructure/Glue | ~180 | Model loading, LoRA setup, optimizer, checkpointing, training loop |
| Unique to Our Use Case | ~400+ | Multi-turn rollout processing, DSL parsing, prompt formatting, observation handling |

## What TRL v0.29.0 Now Provides

TRL's GRPOTrainer (as of Feb 2026) supports:

1. **Multi-turn rollouts** via `rollout_func` (v0.29.0): you provide a custom function
   that replaces TRL's generation loop. Returns `prompt_ids`, `completion_ids`, `logprobs`.
   A Wordle example shows 6-turn interactive loops.

2. **`environment_factory`** (v0.29.0): stateful environments with `reset()` and
   arbitrary methods as tools. One instance per rollout.

3. **Multimodal VLMs** including Qwen2.5-VL: natively supported since v0.20.0.

4. **Custom reward functions**: pass a callable, supports async, multiple functions,
   environment access, extra rollout fields forwarded as kwargs.

5. **LoRA + quantization**: standard PEFT integration.

6. **Gradient accumulation**: standard HF Trainer mechanisms + `steps_per_generation`.

7. **Advanced loss variants**: `dapo`, `dr_grpo`, `bnpo`, asymmetric clipping, Liger kernel fusion.

## Which Issues Vanish With TRL?

~14 of 26 issues are eliminated by delegating to TRL:

- **CR-03** (custom GRPO duplicates TRL): Eliminated by definition
- **CR-07** (untested training loop): TRL is battle-tested
- **IM-03** (no error handling in rollouts): TRL handles generation errors
- **IM-05** (prompt misalignment risk): TRL manages tokenization
- **IM-06** (monkey-patch Unsloth loading): Use TRL's standard model loading
- **IM-07** (LoRA param capture fragile): TRL handles reference policy
- **MD-01** (no gradient clipping): TRL includes it
- **MD-02** (no LR scheduler): TRL includes standard schedulers
- **MD-03** (no WandB logging): TRL integrates with all HF loggers
- **MD-04** (hardcoded AdamW): TRL supports all optimizers
- **MD-05** (no multi-GPU): TRL + accelerate/DeepSpeed handles this
- **MD-06** (no mixed precision): TRL handles bf16/fp16
- **LO-01** (verbose step logging): TRL's logging is configurable
- **LO-02** (no TensorBoard): TRL integrates natively

## The Key Gap: Multi-Turn Interactive Rollouts

TRL's reward model is fundamentally **single-turn**: prompt -> completion -> reward. Even with
`rollout_func`, which lets one completion span several environment turns, the advantage is
computed at the trajectory level (one reward per complete rollout), not per step.

This **matches our use case**:
- WebAgent-R1 uses binary task-success rewards (0 or 1)
- GRPO computes group-relative advantages across N trajectories of the same task
- We don't need per-step credit assignment; a trajectory-level reward is sufficient

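To make the trajectory-level signal concrete, here is a minimal pure-Python sketch of group-relative advantages for binary rewards. The function name `group_advantages` and the `eps` default are illustrative assumptions, not code from the repository; the sample standard deviation is used because that is what `torch.std` computes by default.

```python
from statistics import mean, stdev


def group_advantages(rewards, eps=1e-4):
    """Normalize one group's rewards to zero mean and unit sample std."""
    mu = mean(rewards)
    # stdev() needs at least two points; a single rollout has no spread.
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four rollouts of the same task, one of which succeeded.
mixed = group_advantages([0.0, 0.0, 1.0, 0.0])

# A group where every rollout failed carries no learning signal.
all_failed = group_advantages([0.0, 0.0, 0.0, 0.0])
```

The successful rollout gets a positive advantage and the failures a negative one, while a group with identical rewards yields all-zero advantages. That second case is why binary rewards need groups large enough to contain both outcomes.
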
The `rollout_func` approach lets us:
1. Call our `RLEnvironment.collect_rollout()` to get interactive multi-step trajectories
2. Return the concatenated token IDs and log-probs to TRL
3. Let TRL handle advantage computation, clipping, KL penalty, optimization

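The three steps above can be sketched without importing TRL at all. Everything below is hypothetical: the `Trajectory` container, `make_rollout_func`, and `fake_collect` are stand-ins, and the argument list that TRL actually passes to `rollout_func` is not shown in this document, so the `prompts`-only signature is an assumption. Only the three returned keys come from the text above.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Trajectory:
    """Hypothetical container for one multi-step rollout."""
    prompt_ids: list[int]
    completion_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)


def make_rollout_func(collect: Callable[[str], Trajectory]):
    """Wrap a per-task collector into a batch-level rollout function."""
    def rollout_func(prompts: list[str]) -> dict[str, list]:
        trajectories = [collect(p) for p in prompts]
        for t in trajectories:
            # One log-prob per sampled token, or the loss cannot be aligned.
            assert len(t.completion_ids) == len(t.logprobs)
        return {
            "prompt_ids": [t.prompt_ids for t in trajectories],
            "completion_ids": [t.completion_ids for t in trajectories],
            "logprobs": [t.logprobs for t in trajectories],
        }
    return rollout_func


def fake_collect(task: str) -> Trajectory:
    """Stand-in for an environment-backed collector: two turns, two tokens each."""
    traj = Trajectory(prompt_ids=[ord(c) for c in task])
    for turn in range(2):
        # Tokens from every turn are concatenated into one completion.
        traj.completion_ids += [100 + turn, 200 + turn]
        traj.logprobs += [-0.5, -1.0]
    return traj


batch = make_rollout_func(fake_collect)(["ab", "c"])
```

In a real adapter, `fake_collect` would presumably be replaced by a call into `RLEnvironment.collect_rollout()`, with tokenization of each turn's model output happening inside the collector.
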
## Proposed Architecture

```
TRL GRPOTrainer              <- standard, maintained, tested (0 lines from us)
  |
  +-- rollout_func:          <- ~100 lines (our custom rollout function)
  |     Uses RLEnvironment to collect interactive multi-step trajectories
  |     Returns prompt_ids, completion_ids, logprobs
  |
  +-- reward_func:           <- ~20 lines (already exists in reward.py)
  |     binary_task_success() + compute_group_advantages()
  |
  +-- RolloutCollector       <- ~150 lines (already exists)
  |     collect_group() orchestrates N rollouts per task
  |
  +-- RLEnvironment          <- openadapt-evals (already exists, PR #73)
        reset() / step() / observe() / evaluate()
```

**Our code shrinks from ~800 lines to ~200 lines** of genuine domain-specific logic:
- `rollout_func`: Bridges TRL's generation loop with our interactive environment
- Action DSL parsing (CLICK/TYPE/WAIT/DONE)
- Prompt construction for multi-turn VLM interaction
- Reward function (already exists)

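As an illustration of how small the action-parsing piece can be, the sketch below parses a made-up textual form of the four actions. The concrete syntax (`CLICK(x=..., y=...)`, `TYPE(text="...")`), the `Action` dataclass, and the fallback to `WAIT` on unparseable output are all assumptions for this example; the real DSL grammar is not specified in this document.

```python
import re
from dataclasses import dataclass

_NUM = r"(\d+(?:\.\d+)?)"

# Hypothetical surface syntax for the four actions named above.
_PATTERNS = {
    "CLICK": re.compile(rf"CLICK\(\s*x={_NUM}\s*,\s*y={_NUM}\s*\)"),
    "TYPE": re.compile(r'TYPE\(\s*text="(.*)"\s*\)'),
    "WAIT": re.compile(r"WAIT\(\s*\)"),
    "DONE": re.compile(r"DONE\(\s*\)"),
}


@dataclass
class Action:
    kind: str
    args: dict


def parse_action(text: str) -> Action:
    """Parse one model output into an Action; fall back to WAIT if unparseable."""
    text = text.strip()
    for kind, pattern in _PATTERNS.items():
        match = pattern.fullmatch(text)
        if match is None:
            continue
        if kind == "CLICK":
            return Action(kind, {"x": float(match.group(1)), "y": float(match.group(2))})
        if kind == "TYPE":
            return Action(kind, {"text": match.group(1)})
        return Action(kind, {})
    # Assumed policy: malformed output becomes a no-op rather than an exception.
    return Action("WAIT", {})
```

Falling back to a no-op keeps a rollout alive when the model emits garbage, at the cost of hiding formatting failures; raising instead would be the stricter choice.
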
## What About WebAgent-R1 and Agent-R1?

Both build on **veRL** (ByteDance's RL framework), NOT TRL. They implement their own
multi-turn GRPO from scratch. Key results:
- WebAgent-R1: Qwen-2.5-3B went 6.1% -> 33.9% on WebArena-Lite
- Agent-R1: Supports PPO, GRPO, REINFORCE++ with per-tool-call process rewards

We could also consider veRL, but TRL has better ecosystem integration (HF Hub, PEFT,
quantization, vLLM) and the `rollout_func` API is flexible enough for our needs.

## Standalone GRPO Math (Fallback Option)

If TRL's `rollout_func` proves too constraining, the GRPO math is ~30 lines of PyTorch:

```python
import torch

# Assumed inputs: rewards of shape (B*G,) and per-sequence log-probs
# current_logps / old_logps / ref_logps of the same shape; G is the group
# size, eps the clip range, beta the KL coefficient.

# Advantage (group-normalized)
grouped = rewards.reshape(-1, G)
mean_r = grouped.mean(dim=1, keepdim=True)
std_r = grouped.std(dim=1, keepdim=True)
advantages = ((grouped - mean_r) / (std_r + 1e-4)).flatten()

# KL penalty (Schulman 2020 approximation)
x = ref_logps - current_logps
kl = torch.exp(x) - x - 1

# Clipped surrogate loss (per sample; reduce with .mean() before backward())
ratio = torch.exp(current_logps - old_logps)
clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
loss = -torch.min(ratio * advantages, clipped * advantages) + beta * kl
```

This gives us full control while still eliminating the infrastructure/glue code
by using HF Trainer for the training loop.

## Recommendation

1. **Merge PR #73** (openadapt-evals RL environment): stable foundation, CI passing
2. **Don't merge PR #34 as-is**: the custom trainer has too many issues
3. **Rewrite GRPO module** as thin TRL adapter using `rollout_func`:
   - Keep: rollout_collector.py, reward.py, config.py, cot_warmup.py
   - Replace: trainer.py (800 lines -> ~200 lines)
   - Delete: All custom GRPO math, model loading, optimizer, checkpointing
4. **Close ~14 GitHub issues** that become N/A with TRL delegation

## TRL Version Compatibility Note

TRL v0.29.0 `rollout_func` requires `transformers>=5.2.0`. Verify this works with
Unsloth and our quantization setup before committing to this path.

## References

- [TRL GRPOTrainer docs](https://huggingface.co/docs/trl/main/en/grpo_trainer)
- [TRL OpenEnv integration](https://huggingface.co/docs/trl/main/en/openenv)
- [TRL v0.29.0 release](https://github.com/huggingface/trl/releases/tag/v0.29.0)
- [WebAgent-R1 paper](https://arxiv.org/abs/2505.16421)
- [Agent-R1 (veRL-based)](https://github.com/0russwest0/Agent-R1)
- GitHub issues: openadapt-ml #35-#50, #42 (tracking)
- GitHub issues: openadapt-evals #76-#78

docs/grpo_e2e_test_design.md

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# GRPO E2E Test Design

## Date: 2026-03-02

## Problem

The GRPO trainer was recently rewritten. The existing tests in `tests/test_grpo.py` are
unit tests that mock everything and only verify individual functions in isolation. We need
end-to-end tests that exercise the full training loop and produce artifacts a human can
inspect to verify correctness.

## What a human reviewer needs to see

1. **Did the training loop run without errors?** -- test report with pass/fail, duration,
   error traces.
2. **Did the model weights change?** -- LoRA parameter diff (L2 norm of delta) before vs
   after training. If weights did not change, training is broken.
3. **Were rollouts collected and rewards computed?** -- rollout traces showing the sequence
   of (screenshot, action, reward) for each rollout.
4. **Is the loss signal reasonable?** -- per-step metrics: loss, reward_mean,
   advantage stats, gradient norm.
5. **Can the checkpoint be saved and reloaded?** -- verify the saved LoRA adapter can be
   loaded back.
6. **Does the GRPO loss function actually drive policy toward high-reward actions?** --
   synthetic convergence test with controlled log-probs and rewards.

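The weight-change check in item 2 amounts to a per-parameter L2 norm of the difference between two snapshots. Below is a minimal sketch that uses plain lists in place of tensors; the function name `lora_delta_stats` and the output keys are assumptions for illustration, not the schema of the real `model_diff.json`.

```python
import math


def lora_delta_stats(before: dict[str, list[float]],
                     after: dict[str, list[float]]) -> dict:
    """Return the L2 norm of (after - before) per parameter and overall."""
    per_param = {}
    for name, old in before.items():
        new = after[name]
        per_param[name] = math.sqrt(sum((n - o) ** 2 for n, o in zip(new, old)))
    # The overall norm is the norm of the concatenated deltas.
    total = math.sqrt(sum(v ** 2 for v in per_param.values()))
    return {"per_param_l2": per_param, "total_l2": total, "changed": total > 0.0}


# Toy snapshots: only lora_A moved during "training".
before = {"lora_A": [0.0, 0.0], "lora_B": [1.0, 1.0]}
after = {"lora_A": [3.0, 4.0], "lora_B": [1.0, 1.0]}
stats = lora_delta_stats(before, after)
```

With real tensors the "before" snapshot has to be a detached copy taken prior to the first optimizer step; otherwise both snapshots alias the same storage and the delta is always zero.
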
## Design Options Considered

### Option A: pytest with artifact directory
- Standard pytest tests write artifacts to `test_artifacts/grpo_e2e/`.
- Pros: CI integration, no extra dependencies, familiar.
- Cons: artifacts are just files on disk; need separate step to view.

### Option B: Standalone script
- `scripts/run_e2e_test.py` with HTML report.
- Pros: rich output, self-contained.
- Cons: does not integrate with CI.

### Option C: pytest + HTML report plugin (pytest-html)
- Best of both worlds but adds a dependency.

### Option D: pytest + artifact directory + separate summary script
- pytest writes artifacts; `scripts/grpo_e2e_report.py` reads them and prints a
  formatted summary (or generates HTML).
- Pros: separation of concerns, can re-run report without re-running tests, CI-friendly.
- Cons: two invocations.

### Chosen: Option D

Reasoning:
- The user wants to "look at" results -- a summary script can print a clean, readable
  report without adding pytest-html as a dependency.
- Tests work in CI (pytest) and locally (run report script after).
- Artifacts tell the full story: JSON metrics, PNG screenshots, rollout traces.
- Report script can be extended later to generate HTML without changing tests.

## Test Architecture

### Mock Strategy

We do NOT load a real Qwen2.5-VL model (too slow, too large). Instead:

1. **Mock model**: A tiny `nn.Module` with a single linear layer + LoRA-like trainable
   params. It accepts "input_ids" and returns logits. This lets us test that gradients
   flow and weights update without needing a 7B model.
2. **Mock processor**: Returns pre-built tensors. Has `apply_chat_template`,
   `decode`, and `__call__` methods.
3. **Mock environment**: Generates synthetic screenshots (colored rectangles with text
   via PIL), returns mock `RolloutStep` objects with realistic `BenchmarkObservation`
   and `BenchmarkAction` data. Reward is deterministic based on the action.
4. **Mock rollout collector**: Replaces `GRPORolloutCollector` -- returns pre-built
   `Rollout` objects with mock steps that contain PNG screenshot bytes.

This way:
- The training loop (optimizer, loss computation, checkpointing) is exercised for real.
- Artifacts contain visually meaningful screenshots.
- Tests run in < 60s on CPU.

### Tests

1. **`test_e2e_training_loop_mock`** -- Full loop: 2 training steps, 2 rollouts each.
   Verifies weights change, loss is computed, checkpoint is saved and loadable.

2. **`test_e2e_rollout_collection_mock`** -- Collects rollouts from mock environment,
   saves traces (JSON) and screenshots (PNG) as artifacts.

3. **`test_e2e_grpo_loss_convergence`** -- Synthetic test: creates fake log-probs
   (as trainable parameters) and rewards, runs GRPO loss + optimizer for 50 steps,
   verifies the "policy" shifts probability toward high-reward actions.

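The idea behind the convergence test can be illustrated in pure Python with a three-action softmax policy. This is a simplified sketch, not the test itself: it reuses one fixed group of sampled actions instead of resampling, omits the KL term, and takes on-policy steps so the probability ratio is 1 and clipping never activates. The function name `run_convergence`, the learning rate, and the toy rewards are assumptions.

```python
import math
from statistics import mean, stdev


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def run_convergence(steps=50, lr=0.5):
    logits = [0.0, 0.0, 0.0]        # uniform starting policy over 3 actions
    actions = [0, 1, 2, 2]          # fixed group of sampled actions
    rewards = [0.0, 0.0, 1.0, 1.0]  # only action 2 succeeds
    mu, sigma = mean(rewards), stdev(rewards)
    advantages = [(r - mu) / (sigma + 1e-4) for r in rewards]
    history = []
    for _ in range(steps):
        probs = softmax(logits)
        history.append(probs[2])
        # With ratio == 1 the surrogate gradient w.r.t. the logits is
        # A_i * (onehot(a_i) - probs), summed over the group.
        grad = [0.0] * len(logits)
        for a, adv in zip(actions, advantages):
            for k in range(len(logits)):
                grad[k] += adv * ((1.0 if k == a else 0.0) - probs[k])
        # Gradient ascent on the surrogate objective (descent on the loss).
        logits = [z + lr * g / len(actions) for z, g in zip(logits, grad)]
    return history, softmax(logits)


history, final_probs = run_convergence()
```

The probability of the rewarded action starts at one third and rises monotonically; a check of that shape is what the `convergence/` artifacts are meant to let a reviewer confirm by eye.
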
### Artifacts Written

```
test_artifacts/grpo_e2e/<timestamp>/
  test_report.json          -- overall pass/fail, timing, errors
  training_log.json         -- per-step metrics from the training loop
  rollout_traces/
    step_0_rollout_0.json   -- per-rollout trace
    step_0_rollout_0_screenshot_0.png
    ...
  model_diff.json           -- LoRA weight delta stats
  checkpoint/               -- saved LoRA adapter
  convergence/
    loss_history.json       -- loss values over 50 synthetic steps
    advantage_policy.json   -- policy probabilities over time
  summary.txt               -- human-readable summary
```

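A layout like this can be produced with a few lines of standard-library code. The sketch below is an assumption about how a test helper might do it: the helper names, the `%Y%m%d_%H%M%S` timestamp format, and the JSON payloads are illustrative, and a temporary directory stands in for `test_artifacts/`.

```python
import json
import tempfile
import time
from pathlib import Path


def make_artifact_dir(root: Path) -> Path:
    """Create <root>/grpo_e2e/<timestamp>/ with the expected subdirectories."""
    run_dir = root / "grpo_e2e" / time.strftime("%Y%m%d_%H%M%S")
    for sub in ("rollout_traces", "checkpoint", "convergence"):
        (run_dir / sub).mkdir(parents=True, exist_ok=True)
    return run_dir


def write_json(path: Path, payload: dict) -> None:
    path.write_text(json.dumps(payload, indent=2))


# Demo run against a throwaway directory.
root = Path(tempfile.mkdtemp())
run_dir = make_artifact_dir(root)
write_json(run_dir / "test_report.json",
           {"tests": {"test_e2e_training_loop_mock": "pass"}})
write_json(run_dir / "rollout_traces" / "step_0_rollout_0.json",
           {"steps": [], "reward": 1.0})
```

Keying each run by timestamp keeps earlier runs on disk, which is what lets the report step be re-run later without re-running the tests.
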
### Report Script

`scripts/grpo_e2e_report.py` reads the artifact directory and prints:
- Test status (pass/fail per test)
- Training metrics summary
- Model weight change (did LoRA params move?)
- Convergence check (did loss decrease in synthetic test?)
- File listing of all artifacts

Uses `fire` for CLI: `python scripts/grpo_e2e_report.py <artifact_dir>`
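
A minimal sketch of the reading side, covering only the test-status and file-listing parts, might look like the following. The `summarize` function and the `{"tests": {name: status}}` schema for `test_report.json` are assumptions for illustration; the actual script's schema is not shown here.

```python
import json
import tempfile
from pathlib import Path


def summarize(artifact_dir: str) -> str:
    """Build a plain-text summary of one artifact directory."""
    root = Path(artifact_dir)
    lines = []
    report_path = root / "test_report.json"
    if report_path.exists():
        report = json.loads(report_path.read_text())
        for name, status in sorted(report.get("tests", {}).items()):
            lines.append(f"{name}: {status}")
    else:
        lines.append("test_report.json: missing")
    # List every artifact file relative to the run directory.
    files = sorted(p.relative_to(root).as_posix()
                   for p in root.rglob("*") if p.is_file())
    lines.append(f"{len(files)} artifact file(s)")
    lines.extend(f"  {f}" for f in files)
    return "\n".join(lines)


# Demo against a throwaway directory with a single report file.
demo = Path(tempfile.mkdtemp())
(demo / "test_report.json").write_text(json.dumps({"tests": {"test_a": "pass"}}))
summary = summarize(str(demo))
```

Returning a string rather than printing keeps the function testable; a thin CLI wrapper such as `fire.Fire(summarize)` could then expose it on the command line.
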
