Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions garak/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,11 @@ def __init__(self, name="", config_root=_config):
for k, v in generation_params.items():
setattr(self.model.generation_config, k, v)

self.model.to(self.device)
# A model dispatched by accelerate (hf_device_map set) can have modules
# offloaded to CPU/disk, which makes .to() raise; move it only when
# accelerate did not place it.
if not getattr(self.model, "hf_device_map", None):
self.model.to(self.device)

if stored_env:
os.environ[disable_env_key] = stored_env
Expand All @@ -587,11 +591,19 @@ def generate(
) -> List[Union[Message, None]]:

text_prompt = prompt.last_message().text
image_path = prompt.last_message().data_path
if image_path is None:
# LLaVA needs an image; a text-only prompt has none to open.
raise GarakException(
f"{self.name} requires an image but received a text-only prompt. "
"Use an image+text probe (e.g. visual_jailbreak.*), set a "
"Message.data_path, or run with strict_modality_match to skip "
"text-only probes for this target."
)
try:
image_prompt = self.PIL.Image.open(prompt.last_message().data_path)
image_prompt = self.PIL.Image.open(image_path)
except FileNotFoundError:
file_path = prompt.last_message().data_path
raise FileNotFoundError(f"Cannot open image {file_path}.")
raise FileNotFoundError(f"Cannot open image {image_path}.")
except Exception as e:
raise Exception(e)

Expand Down
86 changes: 85 additions & 1 deletion tests/generators/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from garak.attempt import Conversation, Turn, Message
from garak._config import GarakSubConfig
from garak.exception import TargetNameMissingError
from garak.exception import GarakException, TargetNameMissingError

try:
from PIL import Image, ImageDraw
Expand Down Expand Up @@ -132,3 +132,87 @@ def test_llava_supported_models_list():
assert len(SUPPORTED_MODELS) > 0
for model in SUPPORTED_MODELS:
assert model.startswith("llava-hf/")


def _model_mock(hf_device_map):
"""Fake model whose `.to` is observable; `hf_device_map` mimics accelerate
dispatch (a dict) or no dispatch (None)."""
model = MagicMock(name="Model")
model.to = MagicMock(name="to")
model.hf_device_map = hf_device_map
return model


@pytest.mark.parametrize(
    "hf_device_map, should_move",
    [
        ({"": "cpu"}, False),
        (None, True),
    ],
)
def test_llava_model_move_respects_device_map(
    llava_config, monkeypatch, hf_device_map, should_move
):
    """Check when LLaVA.__init__ calls `model.to()`.

    With `hf_device_map` set (model placed by accelerate) `.to()` must not be
    called; with `hf_device_map` of None it must be called exactly once with
    the generator's device.
    """
    fake_model = _model_mock(hf_device_map)

    def _fake_from_pretrained(name, **kw):
        # Hand back the observable mock instead of loading real weights.
        return fake_model

    monkeypatch.setattr(
        "transformers.LlavaNextForConditionalGeneration.from_pretrained",
        _fake_from_pretrained,
    )
    generator = LLaVA(name=SUPPORTED_MODELS[0], config_root=llava_config)

    if not should_move:
        fake_model.to.assert_not_called()
        return
    fake_model.to.assert_called_once_with(generator.device)


def test_llava_error_on_text_only_prompt(llava_config):
    """A prompt whose last message has no `data_path` makes generate() raise
    GarakException."""
    generator = LLaVA(name=SUPPORTED_MODELS[0], config_root=llava_config)
    # Message built with text only, so no image path is attached.
    text_only_prompt = Conversation([Turn("user", Message(text="foo"))])
    with pytest.raises(GarakException):
        generator.generate(text_only_prompt)


def test_llava_generate_runs_when_model_dispatched_across_devices(
    llava_config, llava_test_image, tmp_path
):
    """generate() completes on a model that accelerate split across devices.

    The test first confirms that `.to()` raises RuntimeError on the dispatched
    model (the call __init__ guards against), then swaps that model and a
    mocked processor into a LLaVA instance and checks the decoded output. The
    processor mock returns a BatchFeature so that the inputs handed to the
    model are real tensors.
    """
    accelerate = pytest.importorskip("accelerate")
    # NOTE(review): only `nn` is bound here; the `torch` name used below is
    # presumably imported at module level -- confirm.
    import torch.nn as nn
    from types import SimpleNamespace
    from transformers.feature_extraction_utils import BatchFeature

    class TinyVLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(16, 8)
            self.head = nn.Linear(8, 16)
            self.generation_config = SimpleNamespace(max_new_tokens=None)

        def generate(self, **inputs):
            hidden = self.embed(inputs["input_ids"])
            logits = self.head(hidden)
            return logits.argmax(dim=-1)

    # One submodule on cpu and one on disk gives a multi-entry hf_device_map.
    offload_dir = tmp_path / "offload"
    dispatched = accelerate.dispatch_model(
        TinyVLM().eval(),
        device_map={"embed": "cpu", "head": "disk"},
        offload_dir=str(offload_dir),
    )
    assert len(dispatched.hf_device_map) > 1
    # Moving the dispatched model is expected to fail.
    with pytest.raises(RuntimeError):
        dispatched.to(torch.device("cpu"))

    token_ids = torch.tensor([[1, 2, 3]])
    fake_processor = MagicMock()
    fake_processor.return_value = BatchFeature({"input_ids": token_ids})
    fake_processor.decode.return_value = "across-device output"

    generator = LLaVA(name=SUPPORTED_MODELS[0], config_root=llava_config)
    generator.model = dispatched
    generator.processor = fake_processor
    generator.device = torch.device("cpu")

    image_message = Message(text="x", data_path=llava_test_image)
    conversation = Conversation([Turn("user", image_message)])
    expected = [Message("across-device output")]
    # Same model that refused `.to()` above must still serve generate().
    assert generator.generate(conversation) == expected
Loading