diff --git a/garak/generators/huggingface.py b/garak/generators/huggingface.py index 549d24f05..19ad33267 100644 --- a/garak/generators/huggingface.py +++ b/garak/generators/huggingface.py @@ -575,7 +575,11 @@ def __init__(self, name="", config_root=_config): for k, v in generation_params.items(): setattr(self.model.generation_config, k, v) - self.model.to(self.device) + # A model dispatched by accelerate (hf_device_map set) can have modules + # offloaded to CPU/disk, which makes .to() raise; move it only when + # accelerate did not place it. + if not getattr(self.model, "hf_device_map", None): + self.model.to(self.device) if stored_env: os.environ[disable_env_key] = stored_env @@ -587,11 +591,19 @@ def generate( ) -> List[Union[Message, None]]: text_prompt = prompt.last_message().text + image_path = prompt.last_message().data_path + if image_path is None: + # LLaVA needs an image; a text-only prompt has none to open. + raise GarakException( + f"{self.name} requires an image but received a text-only prompt. " + "Use an image+text probe (e.g. visual_jailbreak.*), set a " + "Message.data_path, or run with strict_modality_match to skip " + "text-only probes for this target." + ) try: - image_prompt = self.PIL.Image.open(prompt.last_message().data_path) + image_prompt = self.PIL.Image.open(image_path) except FileNotFoundError: - file_path = prompt.last_message().data_path - raise FileNotFoundError(f"Cannot open image {file_path}.") + raise FileNotFoundError(f"Cannot open image {image_path}.") except Exception as e: raise Exception(e) diff --git a/tests/generators/test_llava.py b/tests/generators/test_llava.py index 71574d406..554670eec 100644 --- a/tests/generators/test_llava.py +++ b/tests/generators/test_llava.py @@ -4,7 +4,7 @@ from garak.attempt import Conversation, Turn, Message from garak._config import GarakSubConfig -from garak.exception import TargetNameMissingError +from garak.exception import GarakException, TargetNameMissingError try: from PIL import Image, ImageDraw @@ -132,3 +132,87 @@ def test_llava_supported_models_list(): assert len(SUPPORTED_MODELS) > 0 for model in SUPPORTED_MODELS: assert model.startswith("llava-hf/") + + +def _model_mock(hf_device_map): + """Fake model whose `.to` is observable; `hf_device_map` mimics accelerate + dispatch (a dict) or no dispatch (None).""" + model = MagicMock(name="Model") + model.to = MagicMock(name="to") + model.hf_device_map = hf_device_map + return model + + +@pytest.mark.parametrize( + "hf_device_map, should_move", [({"": "cpu"}, False), (None, True)] +) +def test_llava_model_move_respects_device_map( + llava_config, monkeypatch, hf_device_map, should_move +): + """__init__ must not call `model.to()` when accelerate dispatched the + model (`hf_device_map` set) — that raises on offloaded models — but must move + it when accelerate did not place it.""" + model = _model_mock(hf_device_map) + monkeypatch.setattr( + "transformers.LlavaNextForConditionalGeneration.from_pretrained", + lambda name, **kw: model, + ) + llava = LLaVA(name=SUPPORTED_MODELS[0], config_root=llava_config) + if should_move: + model.to.assert_called_once_with(llava.device) + else: + model.to.assert_not_called() + + +def test_llava_error_on_text_only_prompt(llava_config): + """Text-only prompt (`data_path is None`) raises a GarakException correctly.""" + llava = LLaVA(name=SUPPORTED_MODELS[0], config_root=llava_config) + with pytest.raises(GarakException): + llava.generate(Conversation([Turn("user", Message(text="foo"))])) + + +def test_llava_generate_runs_when_model_dispatched_across_devices( + llava_config, llava_test_image, tmp_path +): + """A model split across devices by accelerate raises on `.to()` (which is why + __init__ guards the move), yet generate() still runs: the processor output is a + BatchFeature, so `.to(self.device)` is an ordinary tensor move and accelerate's + hooks realign inputs per submodule at forward time.""" + accelerate = pytest.importorskip("accelerate") + import torch.nn as nn + from types import SimpleNamespace + from transformers.feature_extraction_utils import BatchFeature + + class TinyVLM(nn.Module): + def __init__(self): + super().__init__() + self.embed = nn.Embedding(16, 8) + self.head = nn.Linear(8, 16) + self.generation_config = SimpleNamespace(max_new_tokens=None) + + def generate(self, **inputs): + return self.head(self.embed(inputs["input_ids"])).argmax(dim=-1) + + # cpu + disk split -> multi-entry hf_device_map with offloaded params on `meta` + model = accelerate.dispatch_model( + TinyVLM().eval(), + device_map={"embed": "cpu", "head": "disk"}, + offload_dir=str(tmp_path / "offload"), + ) + assert len(model.hf_device_map) > 1 + # the move __init__ guards against: this dispatched model cannot be `.to()`-moved + with pytest.raises(RuntimeError): + model.to(torch.device("cpu")) + + processor = MagicMock() + processor.return_value = BatchFeature({"input_ids": torch.tensor([[1, 2, 3]])}) + processor.decode.return_value = "across-device output" + + llava = LLaVA(name=SUPPORTED_MODELS[0], config_root=llava_config) + llava.model = model + llava.processor = processor + llava.device = torch.device("cpu") + + conv = Conversation([Turn("user", Message(text="x", data_path=llava_test_image))]) + # generate() completes even though `.to()` on this same model would have raised + assert llava.generate(conv) == [Message("across-device output")]