1- """VLM model wrapper for TRL compatibility.
1+ """VLM model patching for TRL compatibility.
22
3- TRL's GRPOTrainer was designed for text-only LLMs. During the training
4- step, it calls model.forward(input_ids=...) to recompute logprobs under
5- the current policy. For multimodal VLMs, this forward pass also needs
6- pixel_values and image_grid_thw — but TRL doesn't know about them .
3+ TRL's GRPOTrainer was designed for text-only LLMs. It unwraps models
4+ via Accelerate, which strips any external wrapper class. The fix:
5+ patch the model's forward() method directly on the instance. This
6+ survives unwrapping because it's on the model object, not a wrapper .
77
8- This wrapper solves the problem by caching vision inputs during rollout
9- generation (when we have the images) and injecting them during TRL's
10- forward pass (when TRL only passes input_ids).
8+ Two functions:
9+ - ``patch_model_for_trl(model)``: patches model.forward to inject
10+ cached pixel_values. Returns a ``cache_vision_inputs`` callable.
11+ - ``VLMModelWrapper``: legacy wrapper class (kept for backward compat,
12+ delegates to patch_model_for_trl internally).
1113
1214Usage:
13- from openadapt_evals.training.vlm_wrapper import VLMModelWrapper
15+ from openadapt_evals.training.vlm_wrapper import patch_model_for_trl
1416
15- wrapper = VLMModelWrapper(model)
16- trainer = GRPOTrainer(model=wrapper, ...)
17+ cache_fn = patch_model_for_trl(model)
1718
1819 # During rollout generation:
1920 inputs = processor(text=..., images=[img], return_tensors="pt")
20- wrapper.cache_vision_inputs (inputs)
21- outputs = wrapper .generate(**inputs, ...)
21+ cache_fn (inputs) # cache pixel_values
22+ outputs = model .generate(**inputs, ...) # model sees image ✓
2223
2324 # During TRL's training forward pass:
24- # TRL calls wrapper.forward(input_ids=...) — we inject cached vision inputs
25+ # TRL calls model.forward(input_ids=...) → patched forward injects
26+ # cached pixel_values automatically. Model sees image ✓
2527"""
2628
2729from __future__ import annotations
2830
2931import logging
30- from typing import Any
32+ from typing import Any , Callable
3133
3234logger = logging .getLogger (__name__ )
3335
3436
35- class VLMModelWrapper :
36- """Wraps a VLM so TRL's forward pass gets pixel_values.
def patch_model_for_trl(model: Any) -> Callable[[dict[str, Any]], None]:
    """Patch a VLM's forward() to auto-inject cached vision inputs.

    This patches the model instance directly (not a wrapper class),
    so it survives TRL/Accelerate unwrapping.

    Args:
        model: A HuggingFace VLM (may be a PeftModel).

    Returns:
        A ``cache_vision_inputs(inputs_dict)`` function. Call this during
        rollout generation to cache pixel_values for the training forward.
    """
    # Mutable state shared between cache_vision_inputs and the patched
    # forward. Lists are used for the log flags so the closures can
    # mutate them without ``nonlocal``.
    _cache: dict[str, Any] = {}
    _logged_inject = [False]
    _logged_miss = [False]

    original_forward = model.forward

    def _patched_forward(input_ids: Any = None, **kwargs: Any) -> Any:
        """Forward with automatic vision input injection."""
        if "pixel_values" not in kwargs and _cache:
            for key, val in _cache.items():
                if key not in kwargs:
                    # Move cached tensors onto input_ids' device when both
                    # sides support it; otherwise pass the value through.
                    if hasattr(val, "to") and hasattr(input_ids, "device"):
                        kwargs[key] = val.to(input_ids.device)
                    else:
                        kwargs[key] = val
            if not _logged_inject[0]:
                _logged_inject[0] = True
                logger.info(
                    "VLM forward patch: injecting cached vision inputs "
                    "(keys=%s). TRL called forward() without pixel_values.",
                    list(_cache.keys()),
                )
        elif "pixel_values" not in kwargs and not _cache:
            if not _logged_miss[0]:
                _logged_miss[0] = True
                logger.warning(
                    "VLM forward patch: forward() called without pixel_values "
                    "and no cache. Model is blind. Call cache_fn() first.",
                )
        return original_forward(input_ids=input_ids, **kwargs)

    # Patch the model instance. Patching forward alone is sufficient:
    # Python resolves dunder methods like __call__ on the *class*, so an
    # instance-level __call__ attribute would never be consulted by the
    # model(...) call syntax anyway. HF/torch modules route __call__
    # through self.forward, which picks up this instance-level patch, so
    # both model(...) and model.forward(...) get the injection.
    model.forward = _patched_forward

    logger.info(
        "VLM forward patch installed on %s. Vision inputs will be "
        "auto-injected during TRL's forward passes.",
        type(model).__name__,
    )

    def cache_vision_inputs(inputs: dict[str, Any]) -> None:
        """Cache vision tensors for injection into forward passes.

        Args:
            inputs: Dict from processor(text=..., images=...) or a dict
                with pixel_values and optionally image_grid_thw.
        """
        _cache.clear()
        for key in ("pixel_values", "image_grid_thw"):
            if key in inputs:
                val = inputs[key]
                # Detach/clone tensors so the cache holds no autograd graph
                # references from the rollout pass.
                if hasattr(val, "detach"):
                    _cache[key] = val.detach().clone()
                else:
                    _cache[key] = val
        if _cache:
            logger.debug("Cached vision inputs: keys=%s", list(_cache.keys()))

    return cache_vision_inputs
130+
131+
132+ class VLMModelWrapper :
133+ """Legacy wrapper — delegates to patch_model_for_trl internally.
134+
135+ Kept for backward compatibility with existing code that creates
136+ VLMModelWrapper(model). New code should use patch_model_for_trl()
137+ directly and pass the original model to TRL.
44138 """
45139
46140 def __init__ (self , model : Any ):
47- # Store model WITHOUT going through __setattr__ (which delegates to model)
48141 object .__setattr__ (self , "_vlm_model" , model )
142+ object .__setattr__ (self , "_cache_fn" , patch_model_for_trl (model ))
49143 object .__setattr__ (self , "_vision_cache" , None )
50144 object .__setattr__ (self , "_cache_hits" , 0 )
51145 object .__setattr__ (self , "_cache_misses" , 0 )
52146
53- # --- PEFT / quantization compatibility ---
54- # TRL's validate_quantization_for_training() checks for PEFT via:
55- # 1. isinstance(model, PeftModel) — fails because wrapper isn't PeftModel
56- # 2. hasattr(model, "peft_config") — works via our __getattr__
57- # 3. Checking model.is_quantized / model.quantization_method
58- #
59- # The isinstance check is the blocker. We solve it by making the
60- # wrapper's __class__ inherit from the wrapped model's type, so
61- # isinstance(wrapper, PeftModel) returns True.
147+ # PEFT isinstance compatibility
62148 try :
63149 from peft import PeftModel
64150 if isinstance (model , PeftModel ):
65- # Create a new class that inherits from BOTH our wrapper
66- # and the actual model class. This makes isinstance work
67- # while keeping our forward/generate/cache methods.
68151 combined = type (
69152 "VLMPeftModelWrapper" ,
70153 (VLMModelWrapper , type (model )),
71154 {
72- # Ensure our methods take priority (MRO)
73155 "forward" : VLMModelWrapper .forward ,
74156 "generate" : VLMModelWrapper .generate ,
75157 "__call__" : VLMModelWrapper .__call__ ,
@@ -78,101 +160,28 @@ def __init__(self, model: Any):
78160 },
79161 )
80162 object .__setattr__ (self , "__class__" , combined )
81- logger .info (
82- "VLMModelWrapper: PEFT isinstance compatibility enabled "
83- "(wrapped model is %s)" , type (model ).__name__ ,
84- )
85- except ImportError :
163+ except (ImportError , Exception ):
86164 pass
87- except Exception as exc :
88- # If dynamic class fails, fall back to attribute-level compat
89- logger .warning (
90- "VLMModelWrapper: PEFT isinstance setup failed: %s. "
91- "Falling back to attribute-level compatibility." , exc ,
92- )
93165
    def cache_vision_inputs(self, inputs: dict[str, Any]) -> None:
        """Cache vision tensors (pixel_values / image_grid_thw) for later
        injection into the training forward pass.

        Delegates to the cache function returned by patch_model_for_trl(),
        stored on the wrapper as ``_cache_fn`` in __init__.

        Args:
            inputs: Processor output dict containing pixel_values and
                optionally image_grid_thw.
        """
        # object.__getattribute__ reads the wrapper's own attribute directly,
        # bypassing this class's __getattr__/__setattr__ delegation to the
        # wrapped model.
        cache_fn = object.__getattribute__(self, "_cache_fn")
        cache_fn(inputs)
116169
    def forward(self, input_ids: Any = None, **kwargs: Any) -> Any:
        """Delegate to the wrapped model's forward.

        The underlying model's forward was already patched by
        patch_model_for_trl() in __init__, so cached vision inputs are
        injected there — no injection logic is needed in the wrapper.
        """
        # Bypass __getattr__ delegation to fetch the wrapper's own attribute.
        model = object.__getattribute__(self, "_vlm_model")
        return model.forward(input_ids=input_ids, **kwargs)
155173
    def generate(self, **kwargs: Any) -> Any:
        """Delegate generation to the wrapped model.

        No interception needed: callers are expected to pass pixel_values
        explicitly to generate() during rollout.
        """
        # Bypass __getattr__ delegation to fetch the wrapper's own attribute.
        model = object.__getattribute__(self, "_vlm_model")
        return model.generate(**kwargs)
161177
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Route call syntax to forward() for TRL compatibility."""
        return self.forward(*args, **kwargs)
165180
    def __getattr__(self, name: str) -> Any:
        """Delegate unresolved attribute reads to the wrapped model.

        Only invoked when normal lookup on the wrapper fails, which makes
        the wrapper transparent (e.g. config, parameters()).
        """
        model = object.__getattribute__(self, "_vlm_model")
        return getattr(model, name)
174184
    def __setattr__(self, name: str, value: Any) -> None:
        """Delegate attribute writes to the wrapped model.

        The wrapper's own state is written via object.__setattr__ in
        __init__ precisely to avoid this delegation.
        """
        model = object.__getattribute__(self, "_vlm_model")
        setattr(model, name, value)
0 commit comments