feat: VLMModelWrapper — multimodal compatibility layer for TRL #251
Merged
TRL's `GRPOTrainer` calls `model.forward(input_ids=...)` during training without `pixel_values`. VLMs need `pixel_values` to produce meaningful logits; without them, the model is blind and generates garbage. `VLMModelWrapper` caches vision tensors during rollout generation (when we have the images) and injects them during TRL's forward pass. This is the standard adapter pattern — 120 lines, no TRL internals modified.

- `vlm_wrapper.py`: `VLMModelWrapper` with `cache_vision_inputs` + `forward`
- `trl_wrapper.py`: wraps the model before passing it to `GRPOTrainer`
- `trl_rollout.py`: calls `cache_vision_inputs` before `model.generate`
- 9 tests covering injection, delegation, cache behavior, and warnings

Co-Authored-By: Claude Opus 4.6 (1M context) &lt;noreply@anthropic.com&gt;
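A minimal sketch of the adapter described above. The class and method names (`VLMModelWrapper`, `cache_vision_inputs`) come from this PR, but the body is an assumption: the real wrapper would subclass `torch.nn.Module` so that parameters, `.to()`, and `.train()`/`.eval()` delegate correctly.

```python
class VLMModelWrapper:
    """Caches vision tensors at rollout time and re-injects them into
    forward calls that TRL issues with input_ids only.

    Sketch only: names follow the PR description; the actual
    implementation details are not shown here.
    """

    def __init__(self, model):
        self._model = model
        self._vision_cache = {}

    def cache_vision_inputs(self, **vision_kwargs):
        # Called during rollout generation, when the images are in hand.
        self._vision_cache = {
            k: v for k, v in vision_kwargs.items() if v is not None
        }

    def forward(self, *args, **kwargs):
        # TRL's logprob recomputation passes input_ids but no pixel_values;
        # splice the cached vision tensors back in so the VLM is not blind.
        if self._vision_cache and "pixel_values" not in kwargs:
            kwargs.update(self._vision_cache)
        return self._model(*args, **kwargs)

    __call__ = forward

    def __getattr__(self, name):
        # Anything else (generate, config, ...) delegates to the wrapped model,
        # so callers can treat the wrapper as the model itself.
        return getattr(self._model, name)
```

Note that explicitly supplied `pixel_values` win over the cache: injection only happens when the caller omitted them, which keeps rollout-time calls untouched.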
5 e2e tests (`@pytest.mark.heavy`, CPU-only, skipped in CI):

- `test_generation_sees_pixel_values`: model not blind during rollout
- `test_trl_forward_gets_cached_pixel_values`: wrapper injects into TRL
- `test_output_format_not_garbage`: prompt has DSL format guidance
- `test_no_thinking_tokens_in_template`: no `<think>` in chat template
- `test_vision_changes_logits`: `pixel_values` actually affect logits

2 integration tests (light, run in CI):

- `test_wrapper_used_in_train_source`: `VLMModelWrapper` in `trl_wrapper`
- `test_generate_fn_calls_cache_vision_inputs`: cache call in rollout

Each test maps to a bug class from the March 29 session. Together they prevent the entire class of multimodal TRL failures before they reach the customer.

Co-Authored-By: Claude Opus 4.6 (1M context) &lt;noreply@anthropic.com&gt;
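A hedged pytest-style sketch of two of the behaviors the list above covers (injection and the no-cache warning). The wrapper stub here is a deliberately condensed stand-in so the sketch runs on its own; the PR's real tests exercise the actual `VLMModelWrapper`, and the warning message is an assumption.

```python
import warnings


class _MiniWrapper:
    # Condensed stand-in for the PR's wrapper, just enough for these tests.
    def __init__(self, model):
        self.model, self.cache = model, {}

    def cache_vision_inputs(self, **kw):
        self.cache = dict(kw)

    def forward(self, **kw):
        if self.cache and "pixel_values" not in kw:
            kw.update(self.cache)          # inject cached vision tensors
        elif not self.cache:
            warnings.warn("no vision inputs cached; forward will be text-only")
        return self.model(**kw)


class _StubVLM:
    """Records the kwargs of each forward call."""
    def __init__(self):
        self.seen = []

    def __call__(self, **kw):
        self.seen.append(kw)
        return "logits"


def test_trl_forward_gets_cached_pixel_values():
    stub = _StubVLM()
    w = _MiniWrapper(stub)
    w.cache_vision_inputs(pixel_values="IMG")
    w.forward(input_ids=[1, 2])            # what TRL does: input_ids only
    assert stub.seen[-1]["pixel_values"] == "IMG"


def test_forward_without_cache_warns():
    w = _MiniWrapper(_StubVLM())
    with warnings.catch_warnings(record=True) as rec:
        warnings.simplefilter("always")
        w.forward(input_ids=[1])
    assert any("no vision inputs cached" in str(r.message) for r in rec)
```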
Summary
Root cause of persistent garbage output: TRL's training forward pass calls `model.forward(input_ids=...)` without `pixel_values`. The VLM is blind during logprob recomputation, producing garbage logits.

Fix: `VLMModelWrapper` caches vision tensors during rollout generation and injects them during TRL's forward pass. Standard adapter pattern — 120 lines, no TRL internals modified.

How it works
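A sketch of the call order this implies, with tiny inline stubs so it runs standalone. The two phases (rollout in `trl_rollout`, training forward from `GRPOTrainer`) come from the PR; the helper names and signatures here are assumptions, not the PR's actual code.

```python
trace = []


class TraceVLM:
    # Stub VLM that records whether each call saw pixel_values.
    def generate(self, input_ids, pixel_values=None):
        trace.append(("generate", pixel_values is not None))
        return input_ids

    def __call__(self, input_ids=None, pixel_values=None):
        trace.append(("train_forward", pixel_values is not None))
        return "logits"


class TinyWrapper:
    # Condensed adapter: cache at rollout, inject at training forward.
    def __init__(self, model):
        self.model, self.cache = model, {}

    def cache_vision_inputs(self, **kw):
        self.cache = dict(kw)

    def generate(self, *a, **kw):
        return self.model.generate(*a, **kw)

    def forward(self, **kw):
        return self.model(**{**self.cache, **kw})


wrapper = TinyWrapper(TraceVLM())

# 1. Rollout (trl_rollout): cache the vision tensors, then generate with them.
wrapper.cache_vision_inputs(pixel_values="IMG")
wrapper.generate([1, 2], pixel_values="IMG")

# 2. Training (what GRPOTrainer does): forward with input_ids only;
#    the wrapper injects the cached pixel_values.
wrapper.forward(input_ids=[1, 2])

# Both phases saw vision inputs, even though TRL never passed them.
assert trace == [("generate", True), ("train_forward", True)]
```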
Test plan
🤖 Generated with Claude Code