Commit 339e5d3
feat: add GRPO training module with minimal TRL bridge (#34)
* docs: add experimental roadmap and evidence context to vision
- Add 2x2 experimental matrix (retrieval × fine-tuning) to Core Thesis
- Add evidence context to benchmark table: note it's an internal synthetic
benchmark (~3 UI elements) that validates the pipeline, not real-world
performance. Link to openadapt-evals for ongoing WAA/OSWorld evaluation.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix: use 46.7% consistently in 2x2 matrix
The matrix previously showed a 33-47% range, which conflated preliminary
(n=3) and full (n=45) results. The validated number is 46.7%.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* feat: add GRPO training module for online RL
Add openadapt_ml/training/grpo/ package with:
- GRPOConfig for training hyperparameters
- GRPORolloutCollector connecting to openadapt-evals RLEnvironment
- GRPOTrainer implementing custom GRPO loop for multimodal VLMs
- Binary reward function and group-relative advantage computation
- Chain-of-thought warm-up pipeline for SFT pre-training
- 20 unit tests passing without GPU
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
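The group-relative advantage computation above can be sketched in a few lines. This is a minimal illustration, not the module's actual code: with a binary reward, each rollout in a group for the same task is scored against the group mean and normalized by the group standard deviation, so successes get positive advantages and failures negative ones. The function name `group_relative_advantages` is hypothetical.

```python
import math

def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one rollout group (GRPO-style).

    Each rollout's advantage is its reward minus the group mean,
    divided by the group std (eps added for numerical stability
    when all rewards in the group are identical).
    """
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]

# A group of 4 rollouts for the same task: two succeeded, two failed.
advs = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
```

Note that when every rollout in a group gets the same reward, the advantages are all (near) zero, so that group contributes no gradient signal.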
* fix: address review findings in GRPO module
- Replace copy.deepcopy(model) with LoRA state dict snapshot (prevents OOM)
- Mark _compute_rollout_loss as scaffold with dummy forward pass for grad flow
- Fix collect_rollout call to match RLEnvironment API (task_id in signature)
- Add model.eval()/model.train() toggling around rollout/training phases
- Remove unused gradient_accumulation_steps config field
- Use actual screen_size from RLEnvironment instead of hardcoded 1920x1200
- Clamp CLICK coordinates to [0.0, 1.0] to prevent invalid pixel values
- Validate task_ids non-empty at start of train()
- Export CoT warmup functions from package __init__
- Add BenchmarkAction fallback when openadapt-evals not installed
- Add 9 new tests: action parser (8) + empty task_ids validation (1)
- All 29 tests passing
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
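The clamping and case-insensitivity fixes above can be illustrated together. This is a simplified sketch, not the real `_parse_vlm_output_to_action`: a hypothetical `parse_click` helper matches a DSL action like `CLICK(0.42, 1.37)` case-insensitively and clamps the normalized coordinates to [0.0, 1.0].

```python
import re

# Case-insensitive match for a normalized-coordinate CLICK action.
CLICK_RE = re.compile(r"CLICK\(\s*([-\d.]+)\s*,\s*([-\d.]+)\s*\)", re.IGNORECASE)

def parse_click(text):
    """Parse 'CLICK(x, y)' from model output; return (x, y) or None."""
    m = CLICK_RE.search(text)
    if m is None:
        return None
    # Clamp to [0, 1] so downstream conversion to pixels can never
    # produce out-of-bounds coordinates.
    x = min(1.0, max(0.0, float(m.group(1))))
    y = min(1.0, max(0.0, float(m.group(2))))
    return x, y
```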
* feat: implement GRPO loss computation and fix cot_warmup dependency
Implement the core _compute_rollout_loss method that was previously a
NotImplementedError scaffold. The implementation:
- Reconstructs VLM prompts from rollout observations
- Formats actions back to DSL text via new _format_action_as_text helper
- Computes log-probabilities of action tokens under current policy
- Computes reference policy log-probs via PEFT disable_adapter() with
fallback to manual LoRA weight swapping
- Returns GRPO loss: -advantage * log_prob + kl_coef * KL penalty
Also adds get_api_adapter() factory function to api_adapter.py, fixing
the broken import in cot_warmup.py's generate_cot_annotations().
Additional review fixes from prior session:
- Initialize _is_unsloth and _ref_lora_state in __init__
- Remove dead else branch for task_id selection
- Fix total_loss device placement
- LoRA-only fallback save in checkpoint
- TYPE regex accepts single quotes
- Coordinate clamping in _parse_vlm_output_to_action
40 tests passing (10 new: 8 format_action + 1 roundtrip + 1 api_adapter).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
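The loss shape stated above (-advantage * log_prob + kl_coef * KL penalty) can be written as a scalar sketch. The real method sums token-level log-probabilities of the action under the current policy; this toy version just shows the arithmetic, and the function name is hypothetical.

```python
def grpo_rollout_loss(advantage, action_log_prob, kl_to_ref, kl_coef=0.01):
    """Per-rollout GRPO objective: policy-gradient term plus a KL
    penalty toward the frozen reference policy.

    A positive advantage pushes the action's log-prob up (negative
    loss gradient w.r.t. log_prob); the KL term resists drifting
    too far from the reference policy.
    """
    return -advantage * action_log_prob + kl_coef * kl_to_ref
```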
* refactor: deduplicate GRPO prompts via shared _build_agent_messages
Extract prompt construction into _build_agent_messages() which imports
SYSTEM_PROMPT from next_action.py (the SFT training prompt). This
ensures the GRPO agent uses the same prompt distribution the model was
warm-started on, and guarantees _make_agent_fn and _compute_rollout_loss
use identical prompts (critical for correct log-prob computation).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix(grpo): address critical review findings in GRPO loss computation
- C-01: Store raw model output on action._grpo_raw_text for accurate loss
- C-02: Separate tokenization of prompt/action with concatenation to fix
BPE boundary alignment
- I-01: Prefer LoRA weight swapping over disable_adapter() for reference
policy (captures initial LoRA state after SFT warm-start)
- I-03: Per-step gradient accumulation via immediate backward() to prevent
OOM from building computation graph over all rollout steps
- I-04: Fix unescape order in TYPE parser (backslash before quotes)
- M-03: Pass model_name through get_api_adapter to ApiVLMAdapter
- M-07: Case-insensitive CLICK/TYPE regex in _parse_vlm_output_to_action
- L-01: Extract DEFAULT_SCREEN_SIZE constant, replace all hardcoded values
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
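The C-02 fix above can be sketched with a toy tokenizer. Tokenizing the full prompt+action string in one pass risks a BPE merge across the prompt/action boundary, so the action's token positions become ambiguous; tokenizing each piece separately and concatenating the IDs makes the action span exact. A whitespace splitter stands in for the real processor here.

```python
def tokenize(text):
    """Toy stand-in for the real tokenizer (whitespace split)."""
    return text.split()

def action_token_span(prompt, action):
    """Tokenize prompt and action separately, then concatenate.

    The action tokens are exactly the last len(action_ids) positions,
    regardless of how a single-pass tokenizer might have merged
    tokens across the prompt/action boundary.
    """
    prompt_ids = tokenize(prompt)
    action_ids = tokenize(action)
    input_ids = prompt_ids + action_ids
    start = len(prompt_ids)
    return input_ids, start, len(input_ids)
```

Log-probabilities are then computed only over `input_ids[start:end]`, the action slice.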
* fix(grpo): fix instruction propagation, screen size, weight swap safety
- CR-01: Task instruction was never populated during GRPO rollouts.
WAALiveAdapter._get_observation() does not populate raw_observation,
so the agent prompt said "Goal: " with nothing after it. Fix: store
instruction on Rollout dataclass (populated from env._current_task
in collector), use it in both agent_fn and _compute_rollout_loss.
- IM-01: Change DEFAULT_SCREEN_SIZE from 1920x1200 to 1920x1080 for
consistency with baselines module and standard VM configurations.
Add screen_size field to GRPOConfig so it is configurable.
- IM-02: Add try/finally around LoRA weight swap in
_compute_ref_log_probs. Without this, an exception during the
reference forward pass permanently corrupts the model state.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
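The IM-02 fix is a standard try/finally pattern. A minimal sketch, with a dummy model and a hypothetical `swap_lora_weights` helper in place of the real tensor-copying code: the policy weights are restored even when the reference forward pass raises.

```python
class DummyModel:
    """Stand-in for the LoRA-adapted policy; tracks which weights are loaded."""
    def __init__(self):
        self.weights = "policy"

def swap_lora_weights(model, which):
    # Hypothetical helper; the real code copies LoRA tensors in and out.
    model.weights = which

def compute_ref_log_probs(model, forward):
    """Run one forward pass under the frozen reference LoRA weights."""
    swap_lora_weights(model, "reference")
    try:
        return forward(model)
    finally:
        # Restore the trainable policy weights even if forward() raises;
        # without this, an exception would permanently corrupt the model.
        swap_lora_weights(model, "policy")
```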
* fix(grpo): remove unused torch import in _setup_model
The import torch at line 121 was flagged by ruff (F401) as unused.
The surrounding code only calls .detach().clone() on tensor objects,
which does not require the torch module directly.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* style(grpo): apply ruff formatting to GRPO module files
Run ruff format on cot_warmup.py, rollout_collector.py, and trainer.py
to satisfy the CI ruff formatter check.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* refactor(grpo): replace custom trainer with minimal TRL bridge
Replace 809-line custom GRPO trainer with ~280 lines that:
- Use standard HuggingFace AutoModelForVision2Seq + AutoProcessor + PEFT
LoraConfig instead of Unsloth monkey-patching
- Implement standalone GRPO loss in ~15 lines of PyTorch (clipped
surrogate) instead of custom policy gradient + KL penalty
- Use beta=0.0 (no KL penalty, no reference model) per DAPO/Open-
Reasoner-Zero literature, eliminating weight-swap complexity
- Keep per-step backward to avoid OOM on long trajectories
- Use standard model.save_pretrained() for checkpointing
- Document WHY standalone GRPO math vs TRL GRPOTrainer (VLM multi-turn
image pixel_values not stored in token IDs) and WHEN to switch
Preserves all public API: GRPOTrainer, _parse_vlm_output_to_action,
_format_action_as_text, _build_agent_messages, DEFAULT_SCREEN_SIZE.
All 50 tests pass (44 existing + 6 new for grpo_loss and trainer internals).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
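The "~15 lines of PyTorch" clipped surrogate mentioned above looks roughly like the following. This is a plain-Python scalar sketch (the real version operates on token tensors), and with beta=0 there is no KL term at all.

```python
import math

def clipped_surrogate_loss(logp_new, logp_old, advantage, eps=0.2):
    """PPO-style clipped surrogate, per token, as a scalar sketch.

    ratio = pi_new(a|s) / pi_old(a|s); the clip bounds how far a
    single update can move the policy on any one sample.
    """
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = max(1.0 - eps, min(1.0 + eps, ratio))
    # Pessimistic (min) bound over clipped and unclipped objectives.
    return -min(ratio * advantage, clipped_ratio * advantage)
```

Note that with single-epoch on-policy data, logp_new equals logp_old at update time, so the ratio is exactly 1 and the clip never activates.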
* feat(grpo): add E2E tests with artifact generation and architecture docs
- tests/test_grpo_e2e.py: 5 E2E tests (training loop, rollout collection,
loss convergence, weight diff, mathematical properties) using tiny mock
VLM. Produces 65+ artifacts (JSON traces, PNGs, checkpoints, summaries).
- scripts/grpo_e2e_report.py: CLI report generator for test artifacts
(text + optional HTML output).
- docs/grpo_e2e_test_design.md: design rationale for E2E test approach
- docs/grpo_architecture_analysis.md: analysis of custom vs TRL-based GRPO
- docs/grpo_trl_rewrite_draft.py: TRL v0.29.0 integration research
- docs/strategic_analysis_evals_ml_synergy.md: business/economics analysis
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix(grpo): address self-review findings (BUG-01, CLEAN-01 through -05)
- Rename grpo_loss to policy_gradient_loss with honest docstring: single-epoch
on-policy means ratio=1.0, clipping never fires, this is REINFORCE with
group-relative advantages. Keep grpo_loss as backwards-compatible alias.
- Add public aliases: parse_vlm_output_to_action, format_action_as_text
(drop underscore prefix for public API)
- Export policy_gradient_loss and public functions from __init__.py
- Remove unused config fields: kl_coef (was 0.01 but never used with beta=0),
max_seq_length (never referenced)
- Fix model_name default: Qwen/Qwen2.5-VL-7B-Instruct (not unsloth variant)
- Fix trivial test assertion: grad_norm > 0 (was >= 0, always true)
- Update loss tests to verify gradient direction, not just loss sign
- Add test_public_api_exports for new public names
56 tests pass (51 unit + 5 E2E).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

1 parent 70a0c49 · commit 339e5d3
15 files changed: 5,474 additions, 2 deletions
File tree
- docs
- openadapt_ml
- models
- training/grpo
- scripts
- tests