
Commit fa26d55

abrichr and claude authored
feat: VLMModelWrapper — multimodal compatibility layer for TRL (#251)
* feat: VLMModelWrapper — multimodal compatibility layer for TRL

  TRL's GRPOTrainer calls model.forward(input_ids=...) during training
  without pixel_values. VLMs need pixel_values to produce meaningful
  logits. Without them, the model is blind and generates garbage.

  VLMModelWrapper caches vision tensors during rollout generation (when we
  have the images) and injects them during TRL's forward pass. This is the
  standard adapter pattern — 120 lines, no TRL internals modified.

  - vlm_wrapper.py: VLMModelWrapper with cache_vision_inputs + forward
  - trl_wrapper.py: wraps model before passing to GRPOTrainer
  - trl_rollout.py: calls cache_vision_inputs before model.generate
  - 9 tests covering injection, delegation, cache behavior, warnings

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test: add e2e tests for VLM+TRL pipeline and wrapper integration

  5 e2e tests (@pytest.mark.heavy, CPU-only, skipped in CI):
  - test_generation_sees_pixel_values: model not blind during rollout
  - test_trl_forward_gets_cached_pixel_values: wrapper injects into TRL
  - test_output_format_not_garbage: prompt has DSL format guidance
  - test_no_thinking_tokens_in_template: no <think> in chat template
  - test_vision_changes_logits: pixel_values actually affect logits

  2 integration tests (light, run in CI):
  - test_wrapper_used_in_train_source: VLMModelWrapper in trl_wrapper
  - test_generate_fn_calls_cache_vision_inputs: cache call in rollout

  Each test maps to a bug class from the March 29 session. Together they
  prevent the entire class of multimodal TRL failures before they reach
  the customer.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 93fa395 commit fa26d55

6 files changed

Lines changed: 752 additions & 1 deletion
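The tests listed in the commit message are not part of the diff shown below. The idea behind test_vision_changes_logits is easy to sketch, though: run the same tokens through the model with and without pixel_values and require the logits to differ. A minimal illustration follows; it is not the shipped test code, and it assumes a Hugging Face-style VLM whose processor returns input_ids/attention_mask and whose forward output exposes .logits and tolerates a text-only call.

import torch

def vision_changes_logits(model, processor, image, text: str) -> bool:
    """Return True if supplying pixel_values actually changes the logits."""
    inputs = processor(text=[text], images=[image], return_tensors="pt")
    with torch.no_grad():
        # Full multimodal forward: input_ids + pixel_values (+ image_grid_thw).
        with_vision = model(**inputs).logits
        # "Blind" forward: same tokens, no vision tensors (assumes the model
        # accepts this without raising).
        text_only = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        ).logits
    return not torch.allclose(with_vision, text_only)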

openadapt_evals/training/trl_rollout.py

Lines changed: 5 additions & 0 deletions
@@ -643,6 +643,11 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
         return_tensors="pt", padding=True,
     ).to(device)
 
+    # Cache vision inputs so the VLMModelWrapper can inject
+    # pixel_values during TRL's training forward pass.
+    if hasattr(model, "cache_vision_inputs"):
+        model.cache_vision_inputs(inputs)
+
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
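The hasattr guard keeps generate_fn working whether or not the model has been wrapped: a plain VLM has no cache_vision_inputs, so the call is simply skipped. In context, the rollout-side flow is roughly the sketch below, not the verbatim generate_fn; screenshot_to_image and build_prompt are illustrative names, processor/model/device come from the enclosing scope in the real code, and the final decode step is assumed.

import torch

def generate_fn_sketch(screenshot_bytes: bytes, instruction: str) -> str:
    image = screenshot_to_image(screenshot_bytes)   # hypothetical helper
    prompt = build_prompt(instruction)              # hypothetical helper
    inputs = processor(
        text=[prompt], images=[image],
        return_tensors="pt", padding=True,
    ).to(device)
    # Stash pixel_values / image_grid_thw on the wrapper so TRL's later
    # text-only forward pass can be re-hydrated with them.
    if hasattr(model, "cache_vision_inputs"):
        model.cache_vision_inputs(inputs)
    with torch.no_grad():
        outputs = model.generate(**inputs)
    return processor.batch_decode(outputs, skip_special_tokens=True)[0]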

openadapt_evals/training/trl_wrapper.py

Lines changed: 9 additions & 1 deletion
@@ -147,6 +147,14 @@ def train(self) -> str:
         if self._on_model_loaded:
             self._on_model_loaded(model, processor)
 
+        # --- Wrap model for TRL multimodal compatibility ---
+        # TRL's GRPOTrainer calls model.forward(input_ids=...) during the
+        # training step without pixel_values. VLMs need pixel_values to
+        # produce meaningful logits. The wrapper caches vision inputs from
+        # our rollout generation and injects them into TRL's forward pass.
+        from openadapt_evals.training.vlm_wrapper import VLMModelWrapper
+        vlm_wrapper = VLMModelWrapper(model)
+
         # --- Rollout function (from our config) ---
         from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig
         adapter = WAALiveAdapter(WAALiveConfig(
@@ -260,7 +268,7 @@ def on_step_end(self, args, state, control, **kwargs):
 
         # --- Train ---
         trainer = _TRLTrainer(
-            model=model,
+            model=vlm_wrapper,
             processing_class=processor,
             args=trl_config,
             train_dataset=dataset,
openadapt_evals/training/vlm_wrapper.py

Lines changed: 137 additions & 0 deletions

@@ -0,0 +1,137 @@
"""VLM model wrapper for TRL compatibility.

TRL's GRPOTrainer was designed for text-only LLMs. During the training
step, it calls model.forward(input_ids=...) to recompute logprobs under
the current policy. For multimodal VLMs, this forward pass also needs
pixel_values and image_grid_thw — but TRL doesn't know about them.

This wrapper solves the problem by caching vision inputs during rollout
generation (when we have the images) and injecting them during TRL's
forward pass (when TRL only passes input_ids).

Usage:
    from openadapt_evals.training.vlm_wrapper import VLMModelWrapper

    wrapper = VLMModelWrapper(model)
    trainer = GRPOTrainer(model=wrapper, ...)

    # During rollout generation:
    inputs = processor(text=..., images=[img], return_tensors="pt")
    wrapper.cache_vision_inputs(inputs)
    outputs = wrapper.generate(**inputs, ...)

    # During TRL's training forward pass:
    # TRL calls wrapper.forward(input_ids=...) — we inject cached vision inputs
"""

from __future__ import annotations

import logging
from typing import Any

logger = logging.getLogger(__name__)


class VLMModelWrapper:
    """Wraps a VLM so TRL's forward pass gets pixel_values.

    Caches vision tensors (pixel_values, image_grid_thw) during rollout
    generation and injects them during forward passes that lack them.

    This is the standard adapter pattern for making framework-incompatible
    models work with training frameworks. TRL calls model.forward() with
    only input_ids; we intercept and add the vision inputs.
    """

    def __init__(self, model: Any):
        # Store model WITHOUT going through __setattr__ (which delegates to model)
        object.__setattr__(self, "_vlm_model", model)
        object.__setattr__(self, "_vision_cache", None)
        object.__setattr__(self, "_cache_hits", 0)
        object.__setattr__(self, "_cache_misses", 0)

    def cache_vision_inputs(self, inputs: dict[str, Any]) -> None:
        """Cache vision tensors from a processor output dict.

        Call this during rollout generation, right after processor() and
        before generate(). The cached tensors will be injected into
        subsequent forward() calls that lack pixel_values.

        Args:
            inputs: Dict from processor(text=..., images=...) containing
                pixel_values and optionally image_grid_thw.
        """
        cache = {}
        for key in ("pixel_values", "image_grid_thw"):
            if key in inputs:
                # Clone and detach to avoid gradient issues
                val = inputs[key]
                if hasattr(val, "detach"):
                    cache[key] = val.detach().clone()
                else:
                    cache[key] = val
        if cache:
            object.__setattr__(self, "_vision_cache", cache)

    def forward(self, input_ids: Any = None, **kwargs: Any) -> Any:
        """Forward pass with automatic vision input injection.

        If kwargs lacks pixel_values and we have cached vision inputs,
        inject them. This is the key fix: TRL calls model.forward()
        with only input_ids, but VLMs need pixel_values too.
        """
        model = object.__getattribute__(self, "_vlm_model")
        cache = object.__getattribute__(self, "_vision_cache")

        if "pixel_values" not in kwargs and cache is not None:
            for key, val in cache.items():
                if key not in kwargs:
                    # Move to same device as input_ids
                    if hasattr(val, "to") and hasattr(input_ids, "device"):
                        kwargs[key] = val.to(input_ids.device)
                    else:
                        kwargs[key] = val
            hits = object.__getattribute__(self, "_cache_hits")
            object.__setattr__(self, "_cache_hits", hits + 1)
            if hits == 0:
                logger.info(
                    "VLMModelWrapper: injecting cached vision inputs into "
                    "forward pass (keys=%s). This means TRL called forward() "
                    "without pixel_values — the wrapper is working as intended.",
                    list(cache.keys()),
                )
        elif "pixel_values" not in kwargs and cache is None:
            misses = object.__getattribute__(self, "_cache_misses")
            object.__setattr__(self, "_cache_misses", misses + 1)
            if misses == 0:
                logger.warning(
                    "VLMModelWrapper: forward() called without pixel_values "
                    "and no cached vision inputs available. The model is blind. "
                    "Ensure cache_vision_inputs() is called during generation.",
                )

        return model(input_ids=input_ids, **kwargs)

    def generate(self, **kwargs: Any) -> Any:
        """Generate with the underlying model. No interception needed —
        our generate_fn passes pixel_values explicitly."""
        model = object.__getattribute__(self, "_vlm_model")
        return model.generate(**kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Route __call__ to forward for compatibility with TRL."""
        return self.forward(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Delegate all other attribute access to the wrapped model.

        This makes the wrapper transparent: trainer.model.config,
        trainer.model.parameters(), etc. all work as expected.
        """
        model = object.__getattribute__(self, "_vlm_model")
        return getattr(model, name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Delegate attribute setting to the wrapped model."""
        model = object.__getattribute__(self, "_vlm_model")
        setattr(model, name, value)
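The commit message mentions tests covering injection, delegation, and cache behavior; none of them are part of this diff. A minimal sketch of that kind of check, using a stub model so it runs on CPU without any weights (illustrative only, not the shipped tests):

import torch

from openadapt_evals.training.vlm_wrapper import VLMModelWrapper


class _StubVLM:
    """Stands in for a real VLM; records the kwargs of its last forward call."""

    def __init__(self):
        self.last_kwargs = None
        self.config = {"model_type": "stub"}

    def __call__(self, input_ids=None, **kwargs):
        self.last_kwargs = kwargs
        return torch.zeros(1)

    def generate(self, **kwargs):
        return torch.zeros((1, 4), dtype=torch.long)


stub = _StubVLM()
wrapper = VLMModelWrapper(stub)

# Cache vision tensors the way the rollout code does after processor().
wrapper.cache_vision_inputs({"pixel_values": torch.randn(1, 3, 28, 28)})

# Injection: simulate TRL's text-only forward call.
wrapper.forward(input_ids=torch.tensor([[1, 2, 3]]))
assert "pixel_values" in stub.last_kwargs

# Delegation: attributes not defined on the wrapper fall through to the model.
assert wrapper.config == {"model_type": "stub"}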
