Skip to content

Commit 0f381b1

Browse files
abrichr and claude authored
fix: patch model.forward() directly instead of wrapper class (#253)
TRL unwraps models via Accelerate, stripping wrapper classes. The fix: patch forward() on the model instance itself. This survives unwrapping. - patch_model_for_trl(model) → returns cache_fn - cache_fn(inputs) caches pixel_values from processor output - Patched forward() injects cached pixel_values when TRL omits them - Patched __call__ also injects (covers all call paths) - trl_wrapper passes original model to TRL (not a wrapper) - cache_vision_fn passed through to rollout_func Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e27dafc commit 0f381b1

File tree

5 files changed

+269
-289
lines changed

5 files changed

+269
-289
lines changed

openadapt_evals/training/trl_rollout.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def make_waa_rollout_func(
367367
stuck_threshold: int = 3,
368368
on_before_collect: Optional[Callable] = None,
369369
on_rollout_complete: Optional[Callable] = None,
370+
cache_vision_fn: Optional[Callable] = None,
370371
) -> Callable:
371372
"""Create a TRL-compatible rollout_func for WAA environments.
372373
@@ -645,7 +646,11 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
645646

646647
# Cache vision inputs so the VLMModelWrapper can inject
647648
# pixel_values during TRL's training forward pass.
648-
if hasattr(model, "cache_vision_inputs"):
649+
# Cache vision inputs so the patched forward() can inject
650+
# pixel_values during TRL's training step and generate() calls.
651+
if cache_vision_fn is not None:
652+
cache_vision_fn(dict(inputs.items()) if hasattr(inputs, "items") else inputs)
653+
elif hasattr(model, "cache_vision_inputs"):
649654
model.cache_vision_inputs(inputs)
650655

651656
with torch.no_grad():

openadapt_evals/training/trl_wrapper.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,14 @@ def train(self) -> str:
147147
if self._on_model_loaded:
148148
self._on_model_loaded(model, processor)
149149

150-
# --- Wrap model for TRL multimodal compatibility ---
151-
# TRL's GRPOTrainer calls model.forward(input_ids=...) during the
152-
# training step without pixel_values. VLMs need pixel_values to
153-
# produce meaningful logits. The wrapper caches vision inputs from
154-
# our rollout generation and injects them into TRL's forward pass.
155-
from openadapt_evals.training.vlm_wrapper import VLMModelWrapper
156-
vlm_wrapper = VLMModelWrapper(model)
150+
# --- Patch model for TRL multimodal compatibility ---
151+
# TRL's GRPOTrainer calls model.forward(input_ids=...) and
152+
# model.generate(input_ids=...) without pixel_values. VLMs need
153+
# pixel_values. We patch the model's forward() directly on the
154+
# instance so it survives TRL/Accelerate unwrapping (which strips
155+
# wrapper classes). The cache_fn is passed to rollout_func.
156+
from openadapt_evals.training.vlm_wrapper import patch_model_for_trl
157+
cache_vision_fn = patch_model_for_trl(model)
157158

158159
# --- Rollout function (from our config) ---
159160
from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig
@@ -176,6 +177,7 @@ def train(self) -> str:
176177
temperature=self._config.temperature,
177178
on_before_collect=self._on_before_collect,
178179
on_rollout_complete=self._on_rollout_complete,
180+
cache_vision_fn=cache_vision_fn,
179181
)
180182

181183
# --- Reward ---
@@ -268,7 +270,7 @@ def on_step_end(self, args, state, control, **kwargs):
268270

269271
# --- Train ---
270272
trainer = _TRLTrainer(
271-
model=vlm_wrapper,
273+
model=model,
272274
processing_class=processor,
273275
args=trl_config,
274276
train_dataset=dataset,
Lines changed: 122 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,157 @@
1-
"""VLM model wrapper for TRL compatibility.
1+
"""VLM model patching for TRL compatibility.
22
3-
TRL's GRPOTrainer was designed for text-only LLMs. During the training
4-
step, it calls model.forward(input_ids=...) to recompute logprobs under
5-
the current policy. For multimodal VLMs, this forward pass also needs
6-
pixel_values and image_grid_thw — but TRL doesn't know about them.
3+
TRL's GRPOTrainer was designed for text-only LLMs. It unwraps models
4+
via Accelerate, which strips any external wrapper class. The fix:
5+
patch the model's forward() method directly on the instance. This
6+
survives unwrapping because it's on the model object, not a wrapper.
77
8-
This wrapper solves the problem by caching vision inputs during rollout
9-
generation (when we have the images) and injecting them during TRL's
10-
forward pass (when TRL only passes input_ids).
8+
Two functions:
9+
- ``patch_model_for_trl(model)``: patches model.forward to inject
10+
cached pixel_values. Returns a ``cache_vision_inputs`` callable.
11+
- ``VLMModelWrapper``: legacy wrapper class (kept for backward compat,
12+
delegates to patch_model_for_trl internally).
1113
1214
Usage:
13-
from openadapt_evals.training.vlm_wrapper import VLMModelWrapper
15+
from openadapt_evals.training.vlm_wrapper import patch_model_for_trl
1416
15-
wrapper = VLMModelWrapper(model)
16-
trainer = GRPOTrainer(model=wrapper, ...)
17+
cache_fn = patch_model_for_trl(model)
1718
1819
# During rollout generation:
1920
inputs = processor(text=..., images=[img], return_tensors="pt")
20-
wrapper.cache_vision_inputs(inputs)
21-
outputs = wrapper.generate(**inputs, ...)
21+
cache_fn(inputs) # cache pixel_values
22+
outputs = model.generate(**inputs, ...) # model sees image ✓
2223
2324
# During TRL's training forward pass:
24-
# TRL calls wrapper.forward(input_ids=...) — we inject cached vision inputs
25+
# TRL calls model.forward(input_ids=...) → patched forward injects
26+
# cached pixel_values automatically. Model sees image ✓
2527
"""
2628

2729
from __future__ import annotations
2830

2931
import logging
30-
from typing import Any
32+
from typing import Any, Callable
3133

3234
logger = logging.getLogger(__name__)
3335

3436

35-
class VLMModelWrapper:
36-
"""Wraps a VLM so TRL's forward pass gets pixel_values.
37+
def patch_model_for_trl(model: Any) -> Callable[[dict[str, Any]], None]:
    """Patch a VLM's ``forward()`` to auto-inject cached vision inputs.

    TRL's GRPOTrainer recomputes logprobs by calling
    ``model.forward(input_ids=...)`` without ``pixel_values``. We patch
    ``forward`` directly on the *instance* (not via a wrapper class) so
    the patch survives TRL/Accelerate unwrapping.

    NOTE(review): a previous revision also assigned
    ``model.__call__ = types.MethodType(...)``. That was dead code:
    Python resolves special methods on the *type*, not the instance, so
    ``model(...)`` never consults an instance ``__call__`` attribute.
    It is also unnecessary — ``nn.Module.__call__`` dispatches to
    ``self.forward``, which *is* an instance-attribute lookup, so the
    forward patch below already covers ``model(...)`` call paths.

    Args:
        model: A HuggingFace VLM instance (may be a PeftModel).

    Returns:
        A ``cache_vision_inputs(inputs)`` callable. Call it during
        rollout generation (right after the processor runs) to cache
        ``pixel_values`` / ``image_grid_thw`` for later forward passes.
    """
    import functools

    log = logging.getLogger(__name__)

    # Mutable state shared between cache_vision_inputs and the patched
    # forward (closed over by both).
    _cache: dict[str, Any] = {}
    # Log injection / cache-miss only once each to avoid per-step spam.
    _logged = {"inject": False, "miss": False}

    original_forward = model.forward

    # functools.wraps makes inspect.signature(model.forward) resolve to
    # the original signature (via __wrapped__) — TRL/Accelerate inspect
    # forward's signature to decide which kwargs to pass.
    @functools.wraps(original_forward)
    def _patched_forward(input_ids: Any = None, **kwargs: Any) -> Any:
        """Forward with automatic vision-input injection."""
        if "pixel_values" not in kwargs and _cache:
            for key, val in _cache.items():
                if key in kwargs:
                    continue
                # Move cached tensors onto input_ids' device; pass
                # non-tensor values through unchanged.
                if hasattr(val, "to") and hasattr(input_ids, "device"):
                    kwargs[key] = val.to(input_ids.device)
                else:
                    kwargs[key] = val
            if not _logged["inject"]:
                _logged["inject"] = True
                log.info(
                    "VLM forward patch: injecting cached vision inputs "
                    "(keys=%s). TRL called forward() without pixel_values.",
                    list(_cache.keys()),
                )
        elif "pixel_values" not in kwargs and not _cache:
            if not _logged["miss"]:
                _logged["miss"] = True
                log.warning(
                    "VLM forward patch: forward() called without pixel_values "
                    "and no cache. Model is blind. Call cache_fn() first.",
                )
        return original_forward(input_ids=input_ids, **kwargs)

    # Patch the model instance. nn.Module.__call__ routes through
    # self.forward (instance lookup), so model(...) also hits this patch.
    model.forward = _patched_forward

    log.info(
        "VLM forward patch installed on %s. Vision inputs will be "
        "auto-injected during TRL's forward passes.",
        type(model).__name__,
    )

    def cache_vision_inputs(inputs: dict[str, Any]) -> None:
        """Cache vision tensors for injection into later forward passes.

        Args:
            inputs: Dict from ``processor(text=..., images=...)`` or any
                mapping with ``pixel_values`` and optionally
                ``image_grid_thw``.
        """
        _cache.clear()
        for key in ("pixel_values", "image_grid_thw"):
            if key in inputs:
                val = inputs[key]
                # Detach/clone tensors so the cache holds no grad graph.
                _cache[key] = val.detach().clone() if hasattr(val, "detach") else val
        if _cache:
            log.debug("Cached vision inputs: keys=%s", list(_cache.keys()))

    return cache_vision_inputs
130+
131+
132+
class VLMModelWrapper:
133+
"""Legacy wrapper — delegates to patch_model_for_trl internally.
134+
135+
Kept for backward compatibility with existing code that creates
136+
VLMModelWrapper(model). New code should use patch_model_for_trl()
137+
directly and pass the original model to TRL.
44138
"""
45139

46140
def __init__(self, model: Any):
47-
# Store model WITHOUT going through __setattr__ (which delegates to model)
48141
object.__setattr__(self, "_vlm_model", model)
142+
object.__setattr__(self, "_cache_fn", patch_model_for_trl(model))
49143
object.__setattr__(self, "_vision_cache", None)
50144
object.__setattr__(self, "_cache_hits", 0)
51145
object.__setattr__(self, "_cache_misses", 0)
52146

53-
# --- PEFT / quantization compatibility ---
54-
# TRL's validate_quantization_for_training() checks for PEFT via:
55-
# 1. isinstance(model, PeftModel) — fails because wrapper isn't PeftModel
56-
# 2. hasattr(model, "peft_config") — works via our __getattr__
57-
# 3. Checking model.is_quantized / model.quantization_method
58-
#
59-
# The isinstance check is the blocker. We solve it by making the
60-
# wrapper's __class__ inherit from the wrapped model's type, so
61-
# isinstance(wrapper, PeftModel) returns True.
147+
# PEFT isinstance compatibility
62148
try:
63149
from peft import PeftModel
64150
if isinstance(model, PeftModel):
65-
# Create a new class that inherits from BOTH our wrapper
66-
# and the actual model class. This makes isinstance work
67-
# while keeping our forward/generate/cache methods.
68151
combined = type(
69152
"VLMPeftModelWrapper",
70153
(VLMModelWrapper, type(model)),
71154
{
72-
# Ensure our methods take priority (MRO)
73155
"forward": VLMModelWrapper.forward,
74156
"generate": VLMModelWrapper.generate,
75157
"__call__": VLMModelWrapper.__call__,
@@ -78,101 +160,28 @@ def __init__(self, model: Any):
78160
},
79161
)
80162
object.__setattr__(self, "__class__", combined)
81-
logger.info(
82-
"VLMModelWrapper: PEFT isinstance compatibility enabled "
83-
"(wrapped model is %s)", type(model).__name__,
84-
)
85-
except ImportError:
163+
except (ImportError, Exception):
86164
pass
87-
except Exception as exc:
88-
# If dynamic class fails, fall back to attribute-level compat
89-
logger.warning(
90-
"VLMModelWrapper: PEFT isinstance setup failed: %s. "
91-
"Falling back to attribute-level compatibility.", exc,
92-
)
93165

94166
def cache_vision_inputs(self, inputs: dict[str, Any]) -> None:
95-
"""Cache vision tensors from a processor output dict.
96-
97-
Call this during rollout generation, right after processor() and
98-
before generate(). The cached tensors will be injected into
99-
subsequent forward() calls that lack pixel_values.
100-
101-
Args:
102-
inputs: Dict from processor(text=..., images=...) containing
103-
pixel_values and optionally image_grid_thw.
104-
"""
105-
cache = {}
106-
for key in ("pixel_values", "image_grid_thw"):
107-
if key in inputs:
108-
# Clone and detach to avoid gradient issues
109-
val = inputs[key]
110-
if hasattr(val, "detach"):
111-
cache[key] = val.detach().clone()
112-
else:
113-
cache[key] = val
114-
if cache:
115-
object.__setattr__(self, "_vision_cache", cache)
167+
cache_fn = object.__getattribute__(self, "_cache_fn")
168+
cache_fn(inputs)
116169

117170
def forward(self, input_ids: Any = None, **kwargs: Any) -> Any:
118-
"""Forward pass with automatic vision input injection.
119-
120-
If kwargs lacks pixel_values and we have cached vision inputs,
121-
inject them. This is the key fix: TRL calls model.forward()
122-
with only input_ids, but VLMs need pixel_values too.
123-
"""
124171
model = object.__getattribute__(self, "_vlm_model")
125-
cache = object.__getattribute__(self, "_vision_cache")
126-
127-
if "pixel_values" not in kwargs and cache is not None:
128-
for key, val in cache.items():
129-
if key not in kwargs:
130-
# Move to same device as input_ids
131-
if hasattr(val, "to") and hasattr(input_ids, "device"):
132-
kwargs[key] = val.to(input_ids.device)
133-
else:
134-
kwargs[key] = val
135-
hits = object.__getattribute__(self, "_cache_hits")
136-
object.__setattr__(self, "_cache_hits", hits + 1)
137-
if hits == 0:
138-
logger.info(
139-
"VLMModelWrapper: injecting cached vision inputs into "
140-
"forward pass (keys=%s). This means TRL called forward() "
141-
"without pixel_values — the wrapper is working as intended.",
142-
list(cache.keys()),
143-
)
144-
elif "pixel_values" not in kwargs and cache is None:
145-
misses = object.__getattribute__(self, "_cache_misses")
146-
object.__setattr__(self, "_cache_misses", misses + 1)
147-
if misses == 0:
148-
logger.warning(
149-
"VLMModelWrapper: forward() called without pixel_values "
150-
"and no cached vision inputs available. The model is blind. "
151-
"Ensure cache_vision_inputs() is called during generation.",
152-
)
153-
154-
return model(input_ids=input_ids, **kwargs)
172+
return model.forward(input_ids=input_ids, **kwargs)
155173

156174
def generate(self, **kwargs: Any) -> Any:
157-
"""Generate with the underlying model. No interception needed —
158-
our generate_fn passes pixel_values explicitly."""
159175
model = object.__getattribute__(self, "_vlm_model")
160176
return model.generate(**kwargs)
161177

162178
def __call__(self, *args: Any, **kwargs: Any) -> Any:
163-
"""Route __call__ to forward for compatibility with TRL."""
164179
return self.forward(*args, **kwargs)
165180

166181
def __getattr__(self, name: str) -> Any:
167-
"""Delegate all other attribute access to the wrapped model.
168-
169-
This makes the wrapper transparent: trainer.model.config,
170-
trainer.model.parameters(), etc. all work as expected.
171-
"""
172182
model = object.__getattribute__(self, "_vlm_model")
173183
return getattr(model, name)
174184

175185
def __setattr__(self, name: str, value: Any) -> None:
176-
"""Delegate attribute setting to the wrapped model."""
177186
model = object.__getattribute__(self, "_vlm_model")
178187
setattr(model, name, value)

tests/test_trl_integration.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -301,22 +301,19 @@ def test_wrapper_passes_callbacks_to_rollout_func(self):
301301
# ---------------------------------------------------------------------------
302302

303303

304-
class TestVLMModelWrapperIntegration:
305-
"""Verify VLMModelWrapper is wired into the TRL training pipeline."""
304+
class TestVLMPatchIntegration:
305+
"""Verify VLM model patching is wired into the TRL training pipeline."""
306306

307-
def test_wrapper_used_in_train_source(self):
308-
"""trl_wrapper.train() wraps the model in VLMModelWrapper."""
307+
def test_patch_used_in_train_source(self):
308+
"""trl_wrapper.train() patches the model for VLM compatibility."""
309309
import inspect
310310
from openadapt_evals.training import trl_wrapper
311311

312312
source = inspect.getsource(trl_wrapper.GRPOTrainer.train)
313-
assert "VLMModelWrapper" in source, (
314-
"GRPOTrainer.train() must wrap the model in VLMModelWrapper "
315-
"before passing to TRL. Without this, TRL's forward pass "
316-
"won't have pixel_values and the VLM will be blind."
317-
)
318-
assert "vlm_wrapper" in source.lower() or "VLMModelWrapper(model)" in source, (
319-
"train() must create VLMModelWrapper(model) to wrap the model."
313+
assert "patch_model_for_trl" in source, (
314+
"GRPOTrainer.train() must call patch_model_for_trl(model) "
315+
"to patch forward() for pixel_values injection. Without this, "
316+
"TRL's forward pass won't have pixel_values and the VLM will be blind."
320317
)
321318

322319
def test_generate_fn_calls_cache_vision_inputs(self):

0 commit comments

Comments
 (0)