Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions openadapt_evals/training/vlm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,29 +82,37 @@ def _patched_forward(input_ids: Any = None, **kwargs: Any) -> Any:
# Patch the model instance
model.forward = _patched_forward

# Also patch __call__ if it routes to forward (most HF models do)
# This ensures model(input_ids=...) also gets the injection.
original_call = model.__class__.__call__

def _patched_call(self_model, *args, **kwargs):
# If called without pixel_values, inject from cache
# Also patch generate() — TRL calls model.generate(input_ids=...)
# without pixel_values. HF's generate() calls forward() internally,
# but pixel_values must also be in the generate() kwargs so HF can
# pass them through prepare_inputs_for_generation() → forward().
_logged_gen_inject = [False]
original_generate = model.generate

def _patched_generate(*args: Any, **kwargs: Any) -> Any:
    """Call the original ``generate`` with cached vision inputs injected.

    When the caller did not supply ``pixel_values`` and the vision cache is
    non-empty, every cached entry whose key is absent from ``kwargs`` is
    added. Entries that expose ``.to`` are moved to the device of the prompt
    tensor when one can be found.

    Args:
        *args: Positional arguments forwarded untouched to the original
            ``generate``. Accepted so that callers passing the prompt
            tensor positionally keep working after the patch.
        **kwargs: Keyword arguments for the original ``generate``; missing
            cached vision inputs are added to this mapping.

    Returns:
        Whatever the original ``generate`` returns.
    """
    if "pixel_values" not in kwargs and _cache:
        # Locate the prompt tensor once; it only supplies the target device.
        # NOTE(review): falling back to ``inputs`` / the first positional
        # argument assumes that is where the prompt tensor lives when it is
        # not passed as ``input_ids`` -- confirm against the callers.
        prompt = kwargs.get("input_ids")
        if prompt is None:
            prompt = kwargs.get("inputs")
        if prompt is None and args:
            prompt = args[0]
        target_device = getattr(prompt, "device", None)

        for key, val in _cache.items():
            if key in kwargs:
                # Never override a value the caller passed explicitly.
                continue
            if target_device is not None and hasattr(val, "to"):
                kwargs[key] = val.to(target_device)
            else:
                kwargs[key] = val

        # Log only the first injection to avoid flooding the training logs.
        if not _logged_gen_inject[0]:
            _logged_gen_inject[0] = True
            logger.info(
                "VLM generate patch: injecting cached vision inputs "
                "(keys=%s). TRL called generate() without pixel_values.",
                list(_cache.keys()),
            )
    return original_generate(*args, **kwargs)

# Only patch __call__ on the instance, not the class
import types
model.__call__ = types.MethodType(_patched_call, model)
model.generate = _patched_generate

logger.info(
"VLM forward patch installed on %s. Vision inputs will be "
"auto-injected during TRL's forward passes.",
"VLM patches installed on %s (forward + generate). Vision inputs "
"will be auto-injected during all TRL model calls.",
type(model).__name__,
)

Expand Down
Loading