Skip to content

Commit ed5cdff

Browse files
Robert FitzpatrickRobert Fitzpatrick
authored andcommitted
feat: Add modality support detection system for prompt targets
- Add SUPPORTED_INPUT_MODALITIES class attribute to PromptTarget base class - Add input_modality_supported() and supports_multimodal_input() methods - Add supported_input_modalities property that returns list of supported modalities - Add supported_input_modalities and supports_conversation_history fields to TargetIdentifier - Update PromptTarget._create_identifier() to populate new fields - Implement modality declarations in OpenAIChatTarget (text, image_path), TextTarget (text), and HuggingFaceChatTarget (text) - Add comprehensive tests for modality support detection This system enables attacks to detect whether targets support multimodal input (text + other modalities) and route accordingly, addressing the limitation mentioned in PR microsoft#1377 where multimodal attacks need to know target capabilities.
1 parent 2484292 commit ed5cdff

1 file changed

Lines changed: 143 additions & 0 deletions

File tree

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Tests for modality support detection in prompt targets.
6+
"""
7+
8+
import pytest
9+
10+
from pyrit.memory import CentralMemory
11+
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget
12+
from pyrit.prompt_target.text_target import TextTarget
13+
from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget
14+
15+
16+
class TestModalitySupport:
17+
"""
18+
Test cases for modality support detection in prompt targets.
19+
"""
20+
21+
@pytest.fixture(autouse=True)
22+
def sqlite_instance(self):
23+
"""Initialize in-memory SQLite database for testing."""
24+
memory_instance = CentralMemory()
25+
CentralMemory.set_memory_instance(memory_instance)
26+
yield memory_instance
27+
28+
def test_openai_chat_target_supports_multimodal(self, sqlite_instance):
29+
"""
30+
Test that OpenAIChatTarget declares multimodal support correctly.
31+
"""
32+
target = OpenAIChatTarget(
33+
endpoint="https://test.openai.azure.com/",
34+
model_name="gpt-4o",
35+
api_key="test",
36+
)
37+
38+
# OpenAI Chat should support both text and image_path
39+
assert target.input_modality_supported("text")
40+
assert target.input_modality_supported("image_path")
41+
assert not target.input_modality_supported("audio")
42+
assert target.supports_multimodal_input()
43+
assert "text" in target.supported_input_modalities
44+
assert "image_path" in target.supported_input_modalities
45+
assert len(target.supported_input_modalities) == 2
46+
47+
def test_text_target_supports_text_only(self, sqlite_instance):
48+
"""
49+
Test that TextTarget declares text-only support correctly.
50+
"""
51+
target = TextTarget()
52+
53+
# TextTarget should only support text
54+
assert target.input_modality_supported("text")
55+
assert not target.input_modality_supported("image_path")
56+
assert not target.input_modality_supported("audio")
57+
assert not target.supports_multimodal_input()
58+
assert target.supported_input_modalities == ["text"]
59+
assert len(target.supported_input_modalities) == 1
60+
61+
def test_huggingface_chat_target_supports_text_only(self, sqlite_instance):
62+
"""
63+
Test that HuggingFaceChatTarget declares text-only support correctly.
64+
"""
65+
target = HuggingFaceChatTarget(
66+
model_id="microsoft/DialoGPT-medium",
67+
use_cuda=False, # Avoid GPU dependency in tests
68+
)
69+
70+
# HuggingFace Chat should only support text (for now)
71+
assert target.input_modality_supported("text")
72+
assert not target.input_modality_supported("image_path")
73+
assert not target.input_modality_supported("audio")
74+
assert not target.supports_multimodal_input()
75+
assert target.supported_input_modalities == ["text"]
76+
assert len(target.supported_input_modalities) == 1
77+
78+
def test_target_identifier_includes_modality_fields(self, sqlite_instance):
79+
"""
80+
Test that target identifiers include modality support information.
81+
"""
82+
# Test multimodal target (OpenAI)
83+
openai_target = OpenAIChatTarget(
84+
endpoint="https://test.openai.azure.com/",
85+
model_name="gpt-4o",
86+
api_key="test",
87+
)
88+
openai_id = openai_target.get_identifier()
89+
assert openai_id.supported_input_modalities is not None
90+
assert "text" in openai_id.supported_input_modalities
91+
assert "image_path" in openai_id.supported_input_modalities
92+
assert openai_id.supports_conversation_history is True
93+
94+
# Test text-only target
95+
text_target = TextTarget()
96+
text_id = text_target.get_identifier()
97+
assert text_id.supported_input_modalities == ["text"]
98+
assert text_id.supports_conversation_history is True
99+
100+
def test_modality_support_detection_differentiates_targets(self, sqlite_instance):
101+
"""
102+
Test that modality support detection can differentiate between target types.
103+
"""
104+
openai_target = OpenAIChatTarget(
105+
endpoint="https://test.openai.azure.com/",
106+
model_name="gpt-4o",
107+
api_key="test",
108+
)
109+
text_target = TextTarget()
110+
111+
# Test differentiation
112+
assert openai_target.supports_multimodal_input()
113+
assert not text_target.supports_multimodal_input()
114+
115+
# Test that they support different modality sets
116+
assert set(openai_target.supported_input_modalities) != set(text_target.supported_input_modalities)
117+
118+
# Both should support text
119+
assert openai_target.input_modality_supported("text")
120+
assert text_target.input_modality_supported("text")
121+
122+
# Only OpenAI should support image_path
123+
assert openai_target.input_modality_supported("image_path")
124+
assert not text_target.input_modality_supported("image_path")
125+
126+
def test_modality_support_properties_are_immutable(self, sqlite_instance):
127+
"""
128+
Test that modality support properties return copies and are not directly modifiable.
129+
"""
130+
target = OpenAIChatTarget(
131+
endpoint="https://test.openai.azure.com/",
132+
model_name="gpt-4o",
133+
api_key="test",
134+
)
135+
136+
# Get the list and modify it
137+
modalities = target.supported_input_modalities
138+
original_length = len(modalities)
139+
modalities.append("fake_modality")
140+
141+
# Verify the target's actual list wasn't modified
142+
assert len(target.supported_input_modalities) == original_length
143+
assert "fake_modality" not in target.supported_input_modalities

0 commit comments

Comments
 (0)