diff --git a/openadapt_evals/training/vlm_wrapper.py b/openadapt_evals/training/vlm_wrapper.py index 003b4ff..a6f16c8 100644 --- a/openadapt_evals/training/vlm_wrapper.py +++ b/openadapt_evals/training/vlm_wrapper.py @@ -82,29 +82,37 @@ def _patched_forward(input_ids: Any = None, **kwargs: Any) -> Any: # Patch the model instance model.forward = _patched_forward - # Also patch __call__ if it routes to forward (most HF models do) - # This ensures model(input_ids=...) also gets the injection. - original_call = model.__class__.__call__ - - def _patched_call(self_model, *args, **kwargs): - # If called without pixel_values, inject from cache + # Also patch generate() — TRL calls model.generate(input_ids=...) + # without pixel_values. HF's generate() calls forward() internally, + # but pixel_values must also be in the generate() kwargs so HF can + # pass them through prepare_inputs_for_generation() → forward(). + _logged_gen_inject = [False] + original_generate = model.generate + + def _patched_generate(**kwargs: Any) -> Any: + """Generate with automatic vision input injection.""" if "pixel_values" not in kwargs and _cache: for key, val in _cache.items(): if key not in kwargs: - input_ids = kwargs.get("input_ids", args[0] if args else None) + input_ids = kwargs.get("input_ids") if hasattr(val, "to") and input_ids is not None and hasattr(input_ids, "device"): kwargs[key] = val.to(input_ids.device) else: kwargs[key] = val - return original_call(self_model, *args, **kwargs) + if not _logged_gen_inject[0]: + _logged_gen_inject[0] = True + logger.info( + "VLM generate patch: injecting cached vision inputs " + "(keys=%s). TRL called generate() without pixel_values.", + list(_cache.keys()), + ) + return original_generate(**kwargs) - # Only patch __call__ on the instance, not the class - import types - model.__call__ = types.MethodType(_patched_call, model) + model.generate = _patched_generate logger.info( - "VLM forward patch installed on %s. Vision inputs will be " - "auto-injected during TRL's forward passes.", + "VLM patches installed on %s (forward + generate). Vision inputs " + "will be auto-injected during all TRL model calls.", type(model).__name__, )