feat: VLMModelWrapper — multimodal compatibility layer for TRL#251

Merged
abrichr merged 2 commits into main from feat/vlm-model-wrapper-for-trl on Mar 29, 2026

Conversation

@abrichr abrichr commented Mar 29, 2026

Summary

Root cause of persistent garbage output: TRL's training forward pass calls model.forward(input_ids=...) without pixel_values. The VLM is blind during logprob recomputation, producing garbage logits.

Fix: VLMModelWrapper caches vision tensors during rollout generation and injects them during TRL's forward pass. Standard adapter pattern — 120 lines, no TRL internals modified.

How it works

Rollout (our code):
  inputs = processor(text=..., images=[img])  → has pixel_values
  wrapper.cache_vision_inputs(inputs)          → cached
  wrapper.generate(**inputs)                   → model sees image ✓

Training step (TRL's code):
  wrapper.forward(input_ids=...)               → pixel_values injected from cache
  → model sees image ✓ (was blind before)
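A minimal sketch of the adapter pattern described above. The class and method names (`VLMModelWrapper`, `cache_vision_inputs`) come from the PR; the internals, the `VISION_KEYS` tuple, and the plain-dict cache are illustrative stand-ins for the real tensor handling:

```python
class VLMModelWrapper:
    """Sketch of an adapter that re-injects cached vision tensors
    into forward() calls that arrive without them (e.g. from TRL)."""

    # Which processor outputs count as vision inputs is an assumption here.
    VISION_KEYS = ("pixel_values", "image_grid_thw")

    def __init__(self, model):
        self.model = model
        self._vision_cache = {}

    def cache_vision_inputs(self, inputs):
        # Keep only the vision tensors from the processor output.
        self._vision_cache = {
            k: v for k, v in inputs.items() if k in self.VISION_KEYS
        }

    def generate(self, **inputs):
        # Rollout path: caller already passes pixel_values, just delegate.
        return self.model.generate(**inputs)

    def forward(self, **kwargs):
        # Training path: TRL calls forward(input_ids=...) without
        # pixel_values, so inject the cached vision tensors. setdefault
        # keeps any explicitly passed values untouched.
        for k, v in self._vision_cache.items():
            kwargs.setdefault(k, v)
        return self.model.forward(**kwargs)

    def __getattr__(self, name):
        # Delegate everything else (config, device, ...) to the wrapped model.
        return getattr(self.model, name)
```

The `setdefault` call is the one design choice worth noting: if a caller does supply `pixel_values`, the cache must not clobber it.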

Test plan

  • 9 wrapper tests: injection, delegation, cache, warnings
  • 32 existing TRL tests pass
  • Client GPU test — should see DSL output + non-zero gradients

🤖 Generated with Claude Code

abrichr and others added 2 commits March 29, 2026 18:53
TRL's GRPOTrainer calls model.forward(input_ids=...) during training
without pixel_values. VLMs need pixel_values to produce meaningful
logits. Without them, the model is blind and generates garbage.

VLMModelWrapper caches vision tensors during rollout generation (when
we have the images) and injects them during TRL's forward pass. This
is the standard adapter pattern — 120 lines, no TRL internals modified.

- vlm_wrapper.py: VLMModelWrapper with cache_vision_inputs + forward
- trl_wrapper.py: wraps model before passing to GRPOTrainer
- trl_rollout.py: calls cache_vision_inputs before model.generate
- 9 tests covering injection, delegation, cache behavior, warnings

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
5 e2e tests (@pytest.mark.heavy, CPU-only, skipped in CI):
- test_generation_sees_pixel_values: model not blind during rollout
- test_trl_forward_gets_cached_pixel_values: wrapper injects into TRL
- test_output_format_not_garbage: prompt has DSL format guidance
- test_no_thinking_tokens_in_template: no <think> in chat template
- test_vision_changes_logits: pixel_values actually affect logits

2 integration tests (light, runs in CI):
- test_wrapper_used_in_train_source: VLMModelWrapper in trl_wrapper
- test_generate_fn_calls_cache_vision_inputs: cache call in rollout

Each test maps to a bug class from the March 29 session. Together they
prevent the entire class of multimodal TRL failures before they reach
the customer.
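A hedged sketch of what a light ordering check like `test_generate_fn_calls_cache_vision_inputs` could look like. This is not the PR's actual test code; the Mock-based approach and the call sequence it asserts are assumptions based on the rollout path described above (cache first, then generate):

```python
from unittest.mock import MagicMock

def test_generate_fn_calls_cache_vision_inputs():
    """Rollout must cache vision inputs before calling generate."""
    order = []
    wrapper = MagicMock()
    wrapper.cache_vision_inputs.side_effect = lambda inputs: order.append("cache")
    wrapper.generate.side_effect = lambda **kw: order.append("generate")

    inputs = {"input_ids": [101, 102], "pixel_values": object()}
    # Mirrors the rollout path: cache the vision tensors, then generate.
    wrapper.cache_vision_inputs(inputs)
    wrapper.generate(**inputs)

    assert order == ["cache", "generate"]
```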

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abrichr abrichr merged commit fa26d55 into main Mar 29, 2026
1 check passed
