diff --git a/openadapt_evals/training/trl_rollout.py b/openadapt_evals/training/trl_rollout.py index 3489bca..0181c81 100644 --- a/openadapt_evals/training/trl_rollout.py +++ b/openadapt_evals/training/trl_rollout.py @@ -643,6 +643,11 @@ def generate_fn(screenshot_bytes: bytes, instruction: str): return_tensors="pt", padding=True, ).to(device) + # Cache vision inputs so the VLMModelWrapper can inject + # pixel_values during TRL's training forward pass. + if hasattr(model, "cache_vision_inputs"): + model.cache_vision_inputs(inputs) + with torch.no_grad(): outputs = model.generate( **inputs, diff --git a/openadapt_evals/training/trl_wrapper.py b/openadapt_evals/training/trl_wrapper.py index b4c48ee..891db8a 100644 --- a/openadapt_evals/training/trl_wrapper.py +++ b/openadapt_evals/training/trl_wrapper.py @@ -147,6 +147,14 @@ def train(self) -> str: if self._on_model_loaded: self._on_model_loaded(model, processor) + # --- Wrap model for TRL multimodal compatibility --- + # TRL's GRPOTrainer calls model.forward(input_ids=...) during the + # training step without pixel_values. VLMs need pixel_values to + # produce meaningful logits. The wrapper caches vision inputs from + # our rollout generation and injects them into TRL's forward pass. 
"""VLM model wrapper for TRL compatibility.

TRL's GRPOTrainer was designed for text-only LLMs. During the training
step, it calls model.forward(input_ids=...) to recompute logprobs under
the current policy. For multimodal VLMs, this forward pass also needs
pixel_values and image_grid_thw, but TRL doesn't know about them.

This wrapper caches vision inputs during rollout generation (when we
have the images) and injects them during TRL's forward pass (when TRL
only passes input_ids).

Usage:
    from openadapt_evals.training.vlm_wrapper import VLMModelWrapper

    wrapper = VLMModelWrapper(model)
    trainer = GRPOTrainer(model=wrapper, ...)

    # During rollout generation:
    inputs = processor(text=..., images=[img], return_tensors="pt")
    wrapper.cache_vision_inputs(inputs)
    outputs = wrapper.generate(**inputs, ...)

    # During TRL's training forward pass:
    # TRL calls wrapper.forward(input_ids=...) and the wrapper injects
    # the cached vision inputs.
"""

import logging
from typing import Any, Mapping

logger = logging.getLogger(__name__)

# Processor-output keys that are cached and later injected into forward().
_VISION_KEYS = ("pixel_values", "image_grid_thw")


class VLMModelWrapper:
    """Wrap a VLM so forward passes that lack pixel_values still get them.

    Vision entries (``pixel_values``, ``image_grid_thw``) are captured by
    :meth:`cache_vision_inputs` and added to any later :meth:`forward`
    call whose ``pixel_values`` is missing or ``None``. Every other
    attribute read or write is forwarded to the wrapped model.

    NOTE(review): only the most recent ``cache_vision_inputs()`` call is
    kept. A forward pass over a batch assembled from several rollouts
    receives that single set of vision tensors; confirm this matches how
    the trainer batches its training step.

    NOTE(review): this class does not subclass ``torch.nn.Module``; confirm
    the trainer accepts a duck-typed model.
    """

    def __init__(self, model: Any):
        # Write straight to the instance via object.__setattr__: the
        # class-level __setattr__ forwards every assignment to the model.
        object.__setattr__(self, "_vlm_model", model)
        object.__setattr__(self, "_vision_cache", None)
        object.__setattr__(self, "_cache_hits", 0)
        object.__setattr__(self, "_cache_misses", 0)

    def cache_vision_inputs(self, inputs: Mapping[str, Any]) -> None:
        """Remember the vision entries of a processor output.

        Call this during rollout generation, after the processor call and
        before ``generate()``.

        Args:
            inputs: Mapping that may contain ``pixel_values`` and
                ``image_grid_thw``. All other keys are ignored.

        If ``inputs`` holds no usable vision entry (keys absent or set to
        ``None``), any previously cached entries are left untouched.
        """
        cache = {}
        for key in _VISION_KEYS:
            if key not in inputs:
                continue
            val = inputs[key]
            if val is None:
                # A None placeholder is not a vision tensor; caching it
                # would replace good inputs with a "blind" entry.
                continue
            # Detach and clone tensor-like values so the cache neither
            # joins an autograd graph nor aliases the caller's tensor.
            cache[key] = val.detach().clone() if hasattr(val, "detach") else val
        if cache:
            object.__setattr__(self, "_vision_cache", cache)

    def forward(self, input_ids: Any = None, **kwargs: Any) -> Any:
        """Run the wrapped model, adding cached vision inputs when needed.

        Args:
            input_ids: Passed through as ``input_ids``. When it exposes a
                ``device`` attribute, injected values that have a ``to``
                method are moved to that device first.
            **kwargs: Passed through to the wrapped model. If
                ``pixel_values`` is absent or ``None``, every cached vision
                entry whose key is absent or ``None`` here is filled in.

        Returns:
            Whatever the wrapped model returns.
        """
        model = object.__getattribute__(self, "_vlm_model")
        cache = object.__getattribute__(self, "_vision_cache")

        # An explicit pixel_values=None leaves the model just as blind as
        # an omitted argument, so both are treated as "missing".
        if kwargs.get("pixel_values") is None:
            if cache is not None:
                device = getattr(input_ids, "device", None)
                for key, val in cache.items():
                    if kwargs.get(key) is not None:
                        continue
                    if device is not None and hasattr(val, "to"):
                        val = val.to(device)
                    kwargs[key] = val
                hits = object.__getattribute__(self, "_cache_hits")
                object.__setattr__(self, "_cache_hits", hits + 1)
                if hits == 0:  # log the first injection only
                    logger.info(
                        "VLMModelWrapper: injecting cached vision inputs into "
                        "forward pass (keys=%s). This means TRL called forward() "
                        "without pixel_values — the wrapper is working as intended.",
                        list(cache.keys()),
                    )
            else:
                misses = object.__getattribute__(self, "_cache_misses")
                object.__setattr__(self, "_cache_misses", misses + 1)
                if misses == 0:  # warn on the first miss only
                    logger.warning(
                        "VLMModelWrapper: forward() called without pixel_values "
                        "and no cached vision inputs available. The model is blind. "
                        "Ensure cache_vision_inputs() is called during generation.",
                    )

        return model(input_ids=input_ids, **kwargs)

    def generate(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``generate`` on the wrapped model with the arguments as given.

        Nothing is injected here; the caller supplies the vision inputs.
        Positional arguments are accepted and forwarded unchanged.
        """
        model = object.__getattribute__(self, "_vlm_model")
        return model.generate(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Route calls through :meth:`forward` so injection also applies."""
        return self.forward(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Look up any attribute not found on the wrapper on the wrapped model."""
        model = object.__getattribute__(self, "_vlm_model")
        return getattr(model, name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set the attribute on the wrapped model rather than on the wrapper."""
        model = object.__getattribute__(self, "_vlm_model")
        setattr(model, name, value)
+ """ + torch = pytest.importorskip("torch") + import torch.nn as nn + + vocab_size = 200 + + class TinyVLM(nn.Module): + """Minimal VLM that uses pixel_values in forward.""" + + def __init__(self): + super().__init__() + self.embed = nn.Embedding(vocab_size, 16) + self.vision_proj = nn.Linear(3, 16) + self.head = nn.Linear(16, vocab_size) + self._saw_pixel_values = False + + def forward(self, input_ids, attention_mask=None, + pixel_values=None, image_grid_thw=None, **kwargs): + h = self.embed(input_ids) + if pixel_values is not None: + self._saw_pixel_values = True + # Add vision signal — changes logits when image present + vis = self.vision_proj( + pixel_values.float().mean(dim=(-2, -1)) + ) + if vis.dim() == 2: + vis = vis.unsqueeze(1) + h = h + vis[:, :h.shape[1], :] + + class Out: + pass + out = Out() + out.logits = self.head(h) + return out + + def generate(self, input_ids=None, attention_mask=None, + pixel_values=None, image_grid_thw=None, + max_new_tokens=10, do_sample=True, + temperature=1.0, return_dict_in_generate=False, + output_scores=False, **kwargs): + """Minimal generate: run forward, sample greedily.""" + if pixel_values is not None: + self._saw_pixel_values = True + + all_ids = input_ids.clone() + for _ in range(min(max_new_tokens, 20)): + out = self(all_ids, pixel_values=pixel_values) + next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True) + all_ids = torch.cat([all_ids, next_id], dim=1) + + if return_dict_in_generate: + class GenOut: + pass + result = GenOut() + result.sequences = all_ids + # Fake scores for logprob computation + result.scores = [ + self.head(self.embed(all_ids[:, i:i+1])).squeeze(1) + for i in range(input_ids.shape[1], all_ids.shape[1]) + ] + return result + return all_ids + + class TinyProcessor: + """Minimal processor mimicking Qwen's interface.""" + + def __init__(self): + self.tokenizer = self + # No in template + self.chat_template = ( + "{% for msg in messages %}" + "<|im_start|>{{ msg.role }}\n{{ msg.content 
}}<|im_end|>\n" + "{% endfor %}" + "<|im_start|>assistant\n" + ) + + def apply_chat_template(self, messages, tokenize=False, + add_generation_prompt=True, + enable_thinking=False, **kwargs): + """Render messages to text.""" + parts = [] + for msg in messages: + role = msg["role"] + content = msg["content"] + if isinstance(content, list): + text_parts = [ + c.get("text", "[image]") + for c in content + if c.get("type") in ("text", "image") + ] + content = " ".join(text_parts) + parts.append(f"<|im_start|>{role}\n{content}<|im_end|>") + parts.append("<|im_start|>assistant\n") + return "\n".join(parts) + + def __call__(self, text=None, images=None, return_tensors=None, + padding=False): + """Tokenize text + encode image.""" + # Simple tokenization: each char = 1 token (capped at vocab) + t = text[0] if isinstance(text, list) else text + ids = [min(ord(c), 199) for c in t[:50]] # cap length + result = { + "input_ids": torch.tensor([ids]), + "attention_mask": torch.ones(1, len(ids), dtype=torch.long), + } + if images: + # Create a real pixel_values tensor from the image + from PIL import Image + img = images[0] if isinstance(images, list) else images + if isinstance(img, Image.Image): + import numpy as np + arr = np.array(img.resize((10, 10))) + result["pixel_values"] = torch.tensor( + arr, dtype=torch.float32 + ).permute(2, 0, 1).unsqueeze(0) + result["image_grid_thw"] = torch.tensor([[1, 10, 10]]) + return MagicMock(**{k: v for k, v in result.items()}, + **{"to": lambda self, d: self, + "get": result.get, + "__contains__": result.__contains__, + "__getitem__": result.__getitem__, + "keys": result.keys}) + + def decode(self, ids, skip_special_tokens=True): + return "Thought: test\nAction: DONE()" + + def encode(self, text, add_special_tokens=False): + return [min(ord(c), 199) for c in text[:20]] + + return TinyVLM(), TinyProcessor() + + @staticmethod + def _make_mock_adapter(): + """Mock WAA adapter that returns a fake screenshot.""" + from PIL import Image + + adapter 
= MagicMock() + # Create a real PNG screenshot + img = Image.new("RGB", (100, 100), color=(128, 128, 128)) + buf = io.BytesIO() + img.save(buf, format="PNG") + screenshot = buf.getvalue() + + from openadapt_evals.adapters.base import ( + BenchmarkObservation, BenchmarkResult, BenchmarkTask, + ) + + adapter.observe.return_value = BenchmarkObservation( + screenshot=screenshot, raw_observation={}, + ) + adapter.reset.return_value = BenchmarkObservation( + screenshot=screenshot, raw_observation={}, + ) + adapter.step.return_value = ( + BenchmarkObservation(screenshot=screenshot, raw_observation={}), + True, # done after 1 step + {}, + ) + adapter.load_task.return_value = BenchmarkTask( + task_id="test-001", instruction="Test task", domain="desktop", + ) + adapter.evaluate.return_value = BenchmarkResult( + task_id="test-001", success=False, score=0.0, + ) + adapter.config = MagicMock(server_url="http://localhost:5001") + return adapter + + def test_generation_sees_pixel_values(self): + """The model sees pixel_values during rollout generation.""" + torch = pytest.importorskip("torch") + from openadapt_evals.training.vlm_wrapper import VLMModelWrapper + + model, processor = self._make_tiny_vlm_and_processor() + wrapper = VLMModelWrapper(model) + + # Simulate what generate_fn does + from openadapt_evals.training.standalone.prompt import build_agent_messages + from PIL import Image + + img = Image.new("RGB", (100, 100), color=(128, 128, 128)) + messages = build_agent_messages("Test task", include_image=True) + text_input = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, + ) + inputs = processor(text=[text_input], images=[img], return_tensors="pt") + + # Cache vision inputs (what generate_fn does) + wrapper.cache_vision_inputs(dict(inputs.items())) + + # Generate (what generate_fn does) + outputs = wrapper.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs.get("pixel_values"), + max_new_tokens=5, do_sample=True, + ) + + 
assert model._saw_pixel_values, ( + "Model did not see pixel_values during generation. " + "The model is blind — this produces garbage output." + ) + + def test_trl_forward_gets_cached_pixel_values(self): + """TRL's training forward pass gets pixel_values via the wrapper.""" + torch = pytest.importorskip("torch") + from openadapt_evals.training.vlm_wrapper import VLMModelWrapper + + model, processor = self._make_tiny_vlm_and_processor() + wrapper = VLMModelWrapper(model) + + # Step 1: Rollout generation — cache vision inputs + from PIL import Image + img = Image.new("RGB", (100, 100), color=(128, 128, 128)) + inputs = processor(text=["test prompt"], images=[img], return_tensors="pt") + wrapper.cache_vision_inputs(dict(inputs.items())) + + # Step 2: Simulate TRL's training forward pass (input_ids only!) + model._saw_pixel_values = False + wrapper.forward(input_ids=inputs["input_ids"]) + + assert model._saw_pixel_values, ( + "Model did not see pixel_values during TRL's forward pass. " + "The VLMModelWrapper failed to inject cached vision inputs. " + "This means TRL's logprob recomputation is blind." + ) + + def test_output_format_not_garbage(self): + """Generation produces parseable output, not # # # # #.""" + torch = pytest.importorskip("torch") + from openadapt_evals.training.vlm_wrapper import VLMModelWrapper + + model, processor = self._make_tiny_vlm_and_processor() + wrapper = VLMModelWrapper(model) + + from PIL import Image + img = Image.new("RGB", (100, 100), color=(128, 128, 128)) + + from openadapt_evals.training.standalone.prompt import build_agent_messages + messages = build_agent_messages("Test task", include_image=True) + text_input = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, + ) + + # Verify prompt does NOT contain + assert "" not in text_input, ( + "Prompt contains — model will enter thinking mode " + "and produce garbage tokens instead of DSL actions." 
def test_no_thinking_tokens_in_template(self):
    """Check that the processor's chat template has no opening think tag.

    Reads ``chat_template`` from the processor returned by
    ``_make_tiny_vlm_and_processor`` (treating a missing or falsy value
    as an empty template) and asserts the tag does not occur in it.
    """
    pytest.importorskip("torch")

    _, processor = self._make_tiny_vlm_and_processor()

    # The needle is assembled from pieces so that tooling which strips
    # angle-bracket tags cannot reduce it to "". An empty needle makes
    # `"" not in tpl` False for every template, so the test always fails.
    think_tag = "<" + "think" + ">"

    tpl = getattr(processor, "chat_template", "") or ""
    assert think_tag not in tpl, (
        f"Processor chat_template contains {think_tag}. This activates "
        "Qwen3.5 thinking mode which produces opaque reasoning tokens."
    )
+ ) diff --git a/tests/test_trl_integration.py b/tests/test_trl_integration.py index 3ddb003..0837d65 100644 --- a/tests/test_trl_integration.py +++ b/tests/test_trl_integration.py @@ -294,3 +294,39 @@ def test_wrapper_passes_callbacks_to_rollout_func(self): assert "on_rollout_complete" not in hookbridge_section, ( "HookBridge should not store on_rollout_complete" ) + + +# --------------------------------------------------------------------------- +# VLMModelWrapper integration +# --------------------------------------------------------------------------- + + +class TestVLMModelWrapperIntegration: + """Verify VLMModelWrapper is wired into the TRL training pipeline.""" + + def test_wrapper_used_in_train_source(self): + """trl_wrapper.train() wraps the model in VLMModelWrapper.""" + import inspect + from openadapt_evals.training import trl_wrapper + + source = inspect.getsource(trl_wrapper.GRPOTrainer.train) + assert "VLMModelWrapper" in source, ( + "GRPOTrainer.train() must wrap the model in VLMModelWrapper " + "before passing to TRL. Without this, TRL's forward pass " + "won't have pixel_values and the VLM will be blind." + ) + assert "vlm_wrapper" in source.lower() or "VLMModelWrapper(model)" in source, ( + "train() must create VLMModelWrapper(model) to wrap the model." + ) + + def test_generate_fn_calls_cache_vision_inputs(self): + """generate_fn caches vision inputs on the wrapper before generating.""" + import inspect + from openadapt_evals.training import trl_rollout + + source = inspect.getsource(trl_rollout.make_waa_rollout_func) + assert "cache_vision_inputs" in source, ( + "generate_fn must call model.cache_vision_inputs(inputs) before " + "model.generate() so the VLMModelWrapper can inject pixel_values " + "during TRL's training forward pass." 
+ ) diff --git a/tests/test_vlm_wrapper.py b/tests/test_vlm_wrapper.py new file mode 100644 index 0000000..50fafd1 --- /dev/null +++ b/tests/test_vlm_wrapper.py @@ -0,0 +1,235 @@ +"""Tests for VLMModelWrapper — multimodal TRL compatibility layer. + +Verifies that the wrapper: +1. Injects cached pixel_values into forward() when TRL omits them +2. Passes through pixel_values when already present +3. Delegates generate() and attribute access to the wrapped model +4. Logs appropriately on cache hits and misses +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, call + +import pytest + +from openadapt_evals.training.vlm_wrapper import VLMModelWrapper + + +class _FakeOutput: + def __init__(self, logits): + self.logits = logits + + +class _FakeModel: + """Minimal model that records what it receives.""" + + def __init__(self): + self.last_forward_kwargs = {} + self.config = MagicMock(name="config") + self._params = [MagicMock(name="param")] + + def __call__(self, input_ids=None, **kwargs): + self.last_forward_kwargs = {"input_ids": input_ids, **kwargs} + return _FakeOutput(logits="fake_logits") + + def generate(self, **kwargs): + return "generated_text" + + def parameters(self): + return self._params + + +class TestVLMModelWrapper: + + def test_forward_injects_cached_pixel_values(self): + """TRL calls forward(input_ids=...) 
— wrapper injects cached vision.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + # Simulate rollout: cache vision inputs + wrapper.cache_vision_inputs({ + "pixel_values": "fake_pv", + "image_grid_thw": "fake_thw", + }) + + # Simulate TRL: forward without pixel_values + wrapper.forward(input_ids="fake_ids") + + assert model.last_forward_kwargs["pixel_values"] == "fake_pv" + assert model.last_forward_kwargs["image_grid_thw"] == "fake_thw" + + def test_forward_does_not_override_existing_pixel_values(self): + """If caller passes pixel_values, don't override with cache.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + wrapper.cache_vision_inputs({"pixel_values": "cached_pv"}) + + # Caller explicitly passes pixel_values + wrapper.forward(input_ids="fake_ids", pixel_values="explicit_pv") + + assert model.last_forward_kwargs["pixel_values"] == "explicit_pv" + + def test_forward_without_cache_warns(self, caplog): + """Forward without cache logs a warning on second call.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + import logging + with caplog.at_level(logging.WARNING): + # First call increments miss counter to 1, triggering the warning + wrapper.forward(input_ids="fake_ids") + + assert "no cached vision inputs" in caplog.text.lower() + + def test_generate_delegates_to_model(self): + """generate() passes through to the wrapped model.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + result = wrapper.generate(input_ids="test", max_new_tokens=100) + assert result == "generated_text" + + def test_attribute_delegation(self): + """Attributes are delegated to the wrapped model.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + assert wrapper.config == model.config + assert wrapper.parameters() == model._params + + def test_call_routes_to_forward(self): + """__call__ routes to forward().""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + wrapper.cache_vision_inputs({"pixel_values": 
"pv"}) + wrapper(input_ids="ids") + + assert model.last_forward_kwargs["pixel_values"] == "pv" + + def test_cache_overwrites_previous(self): + """Caching new inputs replaces the old cache.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + wrapper.cache_vision_inputs({"pixel_values": "old_pv"}) + wrapper.cache_vision_inputs({"pixel_values": "new_pv"}) + + wrapper.forward(input_ids="ids") + assert model.last_forward_kwargs["pixel_values"] == "new_pv" + + def test_cache_ignores_non_vision_keys(self): + """Only pixel_values and image_grid_thw are cached.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + wrapper.cache_vision_inputs({ + "pixel_values": "pv", + "input_ids": "should_not_cache", + "attention_mask": "should_not_cache", + }) + + wrapper.forward(input_ids="ids") + assert model.last_forward_kwargs["pixel_values"] == "pv" + assert "attention_mask" not in model.last_forward_kwargs + + def test_empty_cache_from_text_only_inputs(self): + """Processor output without images produces empty cache.""" + model = _FakeModel() + wrapper = VLMModelWrapper(model) + + wrapper.cache_vision_inputs({"input_ids": "only_text"}) + # Cache is None (no vision keys) — forward logs warning + wrapper.forward(input_ids="ids") + assert "pixel_values" not in model.last_forward_kwargs + + +# --------------------------------------------------------------------------- +# Real e2e test with a tiny torch model (requires torch — skipped in CI) +# --------------------------------------------------------------------------- + + +@pytest.mark.heavy +class TestVLMModelWrapperE2E: + """End-to-end test with real torch tensors. + + Verifies that cached pixel_values flow through the wrapper's forward + pass and produce different logits than a blind forward (no images). + Requires torch — skipped in CI via @pytest.mark.heavy. 
+ """ + + @staticmethod + def _make_tiny_vlm(): + """Build a minimal VLM that uses pixel_values in its forward pass.""" + torch = pytest.importorskip("torch") + import torch.nn as nn + + class TinyVLM(nn.Module): + def __init__(self): + super().__init__() + self.embed = nn.Embedding(100, 16) + self.vision_proj = nn.Linear(3, 16) + self.head = nn.Linear(16, 100) + + def forward(self, input_ids, pixel_values=None, **kwargs): + h = self.embed(input_ids) + if pixel_values is not None: + # Add vision signal to the first position + vis = self.vision_proj(pixel_values.mean(dim=(-2, -1))) + h[:, 0, :] += vis.unsqueeze(1).squeeze(1) + logits = self.head(h) + + class Out: + pass + out = Out() + out.logits = logits + return out + + def generate(self, **kwargs): + return self(kwargs["input_ids"], pixel_values=kwargs.get("pixel_values")) + + return TinyVLM() + + def test_forward_with_cached_pixel_values_changes_logits(self): + """Cached pixel_values produce different logits than blind forward.""" + torch = pytest.importorskip("torch") + model = self._make_tiny_vlm() + wrapper = VLMModelWrapper(model) + + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + pixel_values = torch.randn(1, 3, 10, 10) + + # Forward WITHOUT vision (blind) + out_blind = wrapper.forward(input_ids=input_ids) + logits_blind = out_blind.logits.detach().clone() + + # Cache vision inputs + wrapper.cache_vision_inputs({"pixel_values": pixel_values}) + + # Forward WITH cached vision (TRL's training step) + out_vision = wrapper.forward(input_ids=input_ids) + logits_vision = out_vision.logits.detach() + + # Logits should be different when vision is present + assert not torch.allclose(logits_blind, logits_vision, atol=1e-6), ( + "Logits with cached pixel_values should differ from blind logits. " + "If they're the same, the wrapper isn't injecting vision inputs." 
+ ) + + def test_cache_survives_multiple_forward_calls(self): + """Cached pixel_values are reused across multiple forward calls.""" + torch = pytest.importorskip("torch") + model = self._make_tiny_vlm() + wrapper = VLMModelWrapper(model) + + input_ids = torch.tensor([[1, 2, 3]]) + pixel_values = torch.randn(1, 3, 10, 10) + wrapper.cache_vision_inputs({"pixel_values": pixel_values}) + + out1 = wrapper.forward(input_ids=input_ids) + out2 = wrapper.forward(input_ids=input_ids) + + # Both should get the same vision-augmented logits + assert torch.allclose(out1.logits, out2.logits, atol=1e-6)