41 changes: 41 additions & 0 deletions openadapt_evals/training/vlm_wrapper.py
@@ -50,6 +50,47 @@ def __init__(self, model: Any):
        object.__setattr__(self, "_cache_hits", 0)
        object.__setattr__(self, "_cache_misses", 0)

        # --- PEFT / quantization compatibility ---
        # TRL's validate_quantization_for_training() checks for PEFT via:
        # 1. isinstance(model, PeftModel) — fails because wrapper isn't PeftModel
        # 2. hasattr(model, "peft_config") — works via our __getattr__
        # 3. Checking model.is_quantized / model.quantization_method
        #
        # The isinstance check is the blocker. We solve it by making the
        # wrapper's __class__ inherit from the wrapped model's type, so
        # isinstance(wrapper, PeftModel) returns True.
        try:
            from peft import PeftModel
            if isinstance(model, PeftModel):
                # Create a new class that inherits from BOTH our wrapper
                # and the actual model class. This makes isinstance work
                # while keeping our forward/generate/cache methods.
                combined = type(
                    "VLMPeftModelWrapper",
                    (VLMModelWrapper, type(model)),
                    {
                        # Ensure our methods take priority (MRO)
                        "forward": VLMModelWrapper.forward,
                        "generate": VLMModelWrapper.generate,
                        "__call__": VLMModelWrapper.__call__,
                        "cache_vision_inputs": VLMModelWrapper.cache_vision_inputs,
                        "__getattr__": VLMModelWrapper.__getattr__,
                    },
                )
                object.__setattr__(self, "__class__", combined)
                logger.info(
                    "VLMModelWrapper: PEFT isinstance compatibility enabled "
                    "(wrapped model is %s)", type(model).__name__,
                )
        except ImportError:
            pass
        except Exception as exc:
            # If dynamic class fails, fall back to attribute-level compat
            logger.warning(
                "VLMModelWrapper: PEFT isinstance setup failed: %s. "
                "Falling back to attribute-level compatibility.", exc,
            )

    def cache_vision_inputs(self, inputs: dict[str, Any]) -> None:
        """Cache vision tensors from a processor output dict.

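To make the mechanism above easier to follow outside the diff, here is a minimal, self-contained sketch of the same dynamic-subclass technique. Base and Wrapper are illustrative stand-ins, not code from this PR: the wrapper swaps its own __class__ for a type that inherits from both the wrapper class and the wrapped object's class, so isinstance checks against the wrapped class pass while the wrapper's own methods stay first in the MRO.

class Base:
    def forward(self):
        return "base"

class Wrapper:
    def __init__(self, model):
        object.__setattr__(self, "_model", model)
        # Combine the wrapper class with the wrapped object's class so that
        # isinstance(wrapper, type(model)) becomes True.
        combined = type(
            "CombinedWrapper",
            (Wrapper, type(model)),
            {"forward": Wrapper.forward},  # keep the wrapper's forward first
        )
        object.__setattr__(self, "__class__", combined)

    def forward(self):
        return "wrapper"

    def __getattr__(self, name):
        # Anything the wrapper doesn't define falls through to the model.
        return getattr(object.__getattribute__(self, "_model"), name)

base = Base()
base.adapter_name = "default"       # stand-in for peft_config and friends
w = Wrapper(base)
assert isinstance(w, Base)          # the isinstance check now passes
assert w.forward() == "wrapper"     # wrapper methods still win via the MRO
assert w.adapter_name == "default"  # other attributes delegate to the model

The attribute-level fallback mentioned in the except branch relies on the same kind of __getattr__ delegation shown here; it just skips the __class__ swap.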
53 changes: 53 additions & 0 deletions tests/test_trl_e2e.py
@@ -328,3 +328,56 @@ def test_vision_changes_logits(self):
"Either the model ignores vision inputs or the wrapper "
"isn't injecting them. Training will produce zero gradient."
)

def test_wrapper_passes_peft_validation(self):
"""VLMModelWrapper passes TRL's PEFT/quantization validation.

TRL checks isinstance(model, PeftModel) to verify adapters are
attached to quantized models. The wrapper must pass this check.
Without it: ValueError: "You cannot perform fine-tuning on purely
quantized models. Please attach trainable adapters."
"""
torch = pytest.importorskip("torch")

try:
from peft import PeftModel
except ImportError:
pytest.skip("peft not installed")

from openadapt_evals.training.vlm_wrapper import VLMModelWrapper

# Create a mock PeftModel (has peft_config, active_adapter, etc.)
model = MagicMock(spec=PeftModel)
model.peft_config = {"default": MagicMock()}
model.active_adapter = "default"
model.parameters.return_value = iter([torch.zeros(1, requires_grad=True)])

wrapper = VLMModelWrapper(model)

# The critical check TRL performs
assert isinstance(wrapper, PeftModel), (
"isinstance(wrapper, PeftModel) must be True. "
"TRL rejects quantized models without PEFT adapters."
)
assert hasattr(wrapper, "peft_config"), (
"wrapper must expose peft_config for TRL validation."
)

def test_wrapper_preserves_trainable_parameters(self):
"""VLMModelWrapper exposes trainable parameters for the optimizer.

TRL needs model.parameters() to set up the optimizer. The wrapper
must delegate this to the wrapped model.
"""
torch = pytest.importorskip("torch")
from openadapt_evals.training.vlm_wrapper import VLMModelWrapper

model, _ = self._make_tiny_vlm_and_processor()
wrapper = VLMModelWrapper(model)

# Verify parameters are accessible and trainable
params = list(wrapper.parameters())
assert len(params) > 0, "Wrapper must expose model parameters"
assert any(p.requires_grad for p in params), (
"At least some parameters must require grad for training"
)
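The two tests above target TRL's quantization/PEFT validation. As a reading aid, below is an approximate paraphrase of that check; the function name and the exact attributes inspected are assumptions for illustration, not TRL source, and only the error message is taken from the docstring above.

from peft import PeftModel

def validate_quantization_sketch(model) -> None:
    # Paraphrase of the validation described in the docstring above,
    # not the actual TRL implementation.
    is_quantized = getattr(model, "is_quantized", False)
    has_adapters = isinstance(model, PeftModel) or hasattr(model, "peft_config")
    if is_quantized and not has_adapters:
        raise ValueError(
            "You cannot perform fine-tuning on purely quantized models. "
            "Please attach trainable adapters."
        )

The wrapper passes both branches of a check shaped like this: isinstance succeeds because of the dynamic class swap, and hasattr succeeds through __getattr__ delegation.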
46 changes: 46 additions & 0 deletions tests/test_vlm_wrapper.py
@@ -145,6 +145,52 @@ def test_empty_cache_from_text_only_inputs(self):
        wrapper.forward(input_ids="ids")
        assert "pixel_values" not in model.last_forward_kwargs

    def test_peft_attributes_delegated(self):
        """PEFT attributes are accessible through the wrapper."""
        model = _FakeModel()
        model.peft_config = {"default": "lora_config"}
        model.active_adapter = "default"
        wrapper = VLMModelWrapper(model)

        assert wrapper.peft_config == {"default": "lora_config"}
        assert wrapper.active_adapter == "default"

    def test_hasattr_peft_config(self):
        """hasattr(wrapper, 'peft_config') returns True when model has it."""
        model = _FakeModel()
        model.peft_config = {"default": "config"}
        wrapper = VLMModelWrapper(model)

        assert hasattr(wrapper, "peft_config"), (
            "hasattr(wrapper, 'peft_config') must return True for TRL's "
            "validate_quantization_for_training() to pass."
        )

    def test_hasattr_peft_config_false_when_missing(self):
        """hasattr(wrapper, 'peft_config') returns False when model lacks it."""
        model = _FakeModel()
        wrapper = VLMModelWrapper(model)

        assert not hasattr(wrapper, "peft_config")

    def test_isinstance_peft_model(self):
        """isinstance(wrapper, PeftModel) works when PEFT is available."""
        try:
            from peft import PeftModel
        except ImportError:
            pytest.skip("peft not installed")

        # Create a mock that isinstance recognizes as PeftModel
        model = MagicMock(spec=PeftModel)
        model.peft_config = {"default": "config"}
        wrapper = VLMModelWrapper(model)

        assert isinstance(wrapper, PeftModel), (
            "isinstance(wrapper, PeftModel) must return True. "
            "TRL's validation uses isinstance to detect PEFT adapters. "
            "Without this, TRL rejects quantized models."
        )


# ---------------------------------------------------------------------------
# Real e2e test with a tiny torch model (requires torch — skipped in CI)
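A note on the mock-based tests in this file and in test_trl_e2e.py: they rely on standard unittest.mock behavior, where a mock built with spec=SomeClass reports that class through __class__, so isinstance checks pass without constructing a real PeftModel. A minimal demonstration with a placeholder class:

from unittest.mock import MagicMock

class Example:  # placeholder spec class, standing in for PeftModel
    peft_config = {}

mock = MagicMock(spec=Example)
assert isinstance(mock, Example)             # spec sets __class__ for isinstance
assert hasattr(mock, "peft_config")          # attributes on the spec are allowed
assert not hasattr(mock, "not_on_the_spec")  # attributes off the spec raise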